|  |  | @@ -1,10 +1,47 @@ | 
		
	
		
			
			|  |  |  | from icosagon.fastconv import _sparse_diag_cat, \ | 
		
	
		
			
			|  |  |  | _cat, \ | 
		
	
		
			
			|  |  |  | FastGraphConv | 
		
	
		
			
			|  |  |  | FastGraphConv, \ | 
		
	
		
			
			|  |  |  | FastConvLayer | 
		
	
		
			
			|  |  |  | from icosagon.data import _equal | 
		
	
		
			
			|  |  |  | import torch | 
		
	
		
			
			|  |  |  | import pdb | 
		
	
		
			
			|  |  |  | import time | 
		
	
		
			
			|  |  |  | from icosagon.data import Data | 
		
	
		
			
			|  |  |  | from icosagon.input import OneHotInputLayer | 
		
	
		
			
			|  |  |  | from icosagon.convlayer import DecagonLayer | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def _make_symmetric(x: torch.Tensor): | 
		
	
		
			
			|  |  |  | x = (x + x.transpose(0, 1)) / 2 | 
		
	
		
			
			|  |  |  | return x | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def _symmetric_random(n_rows, n_columns): | 
		
	
		
			
			|  |  |  | return _make_symmetric(torch.rand((n_rows, n_columns), | 
		
	
		
			
			|  |  |  | dtype=torch.float32).round()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def _some_data_with_interactions(): | 
		
	
		
			
			|  |  |  | d = Data() | 
		
	
		
			
			|  |  |  | d.add_node_type('Gene', 1000) | 
		
	
		
			
			|  |  |  | d.add_node_type('Drug', 100) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | fam = d.add_relation_family('Drug-Gene', 1, 0, True) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Target', | 
		
	
		
			
			|  |  |  | torch.rand((100, 1000), dtype=torch.float32).round()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | fam = d.add_relation_family('Gene-Gene', 0, 0, True) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Interaction', | 
		
	
		
			
			|  |  |  | _symmetric_random(1000, 1000)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | fam = d.add_relation_family('Drug-Drug', 1, 1, True) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Side Effect: Nausea', | 
		
	
		
			
			|  |  |  | _symmetric_random(100, 100)) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Side Effect: Infertility', | 
		
	
		
			
			|  |  |  | _symmetric_random(100, 100)) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Side Effect: Death', | 
		
	
		
			
			|  |  |  | _symmetric_random(100, 100)) | 
		
	
		
			
			|  |  |  | return d | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_sparse_diag_cat_01(): | 
		
	
	
		
			
				|  |  | @@ -86,3 +123,51 @@ def test_fast_graph_conv_02(): | 
		
	
		
			
			|  |  |  | t = time.time() | 
		
	
		
			
			|  |  |  | _ = fgc(in_repr) | 
		
	
		
			
			|  |  |  | print('FGC forward pass took:', time.time() - t) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_fast_graph_conv_03(): | 
		
	
		
			
			|  |  |  | adj_mat = [ | 
		
	
		
			
			|  |  |  | [ 0, 0, 1, 0, 1 ], | 
		
	
		
			
			|  |  |  | [ 0, 1, 0, 1, 0 ], | 
		
	
		
			
			|  |  |  | [ 1, 0, 1, 0, 0 ] | 
		
	
		
			
			|  |  |  | ] | 
		
	
		
			
			|  |  |  | in_repr = torch.rand(5, 32) | 
		
	
		
			
			|  |  |  | adj_mat = torch.tensor(adj_mat, dtype=torch.float32) | 
		
	
		
			
			|  |  |  | fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse() ]) | 
		
	
		
			
			|  |  |  | out_repr = fgc(in_repr) | 
		
	
		
			
			|  |  |  | assert out_repr.shape == (1, 3, 64) | 
		
	
		
			
			|  |  |  | assert (torch.mm(adj_mat, torch.mm(in_repr, fgc.weights)).view(1, 3, 64) == out_repr).all() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_fast_graph_conv_04(): | 
		
	
		
			
			|  |  |  | adj_mat = [ | 
		
	
		
			
			|  |  |  | [ 0, 0, 1, 0, 1 ], | 
		
	
		
			
			|  |  |  | [ 0, 1, 0, 1, 0 ], | 
		
	
		
			
			|  |  |  | [ 1, 0, 1, 0, 0 ] | 
		
	
		
			
			|  |  |  | ] | 
		
	
		
			
			|  |  |  | in_repr = torch.rand(5, 32) | 
		
	
		
			
			|  |  |  | adj_mat = torch.tensor(adj_mat, dtype=torch.float32) | 
		
	
		
			
			|  |  |  | fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse(), adj_mat.to_sparse() ]) | 
		
	
		
			
			|  |  |  | out_repr = fgc(in_repr) | 
		
	
		
			
			|  |  |  | assert out_repr.shape == (2, 3, 64) | 
		
	
		
			
			|  |  |  | adj_mat_1 = torch.zeros(adj_mat.shape[0] * 2, adj_mat.shape[1] * 2) | 
		
	
		
			
			|  |  |  | adj_mat_1[0:3, 0:5] = adj_mat | 
		
	
		
			
			|  |  |  | adj_mat_1[3:6, 5:10] = adj_mat | 
		
	
		
			
			|  |  |  | res = torch.mm(in_repr, fgc.weights) | 
		
	
		
			
			|  |  |  | res = torch.split(res, res.shape[1] // 2, dim=1) | 
		
	
		
			
			|  |  |  | res = torch.cat(res) | 
		
	
		
			
			|  |  |  | res = torch.mm(adj_mat_1, res) | 
		
	
		
			
			|  |  |  | assert (res.view(2, 3, 64) == out_repr).all() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_fast_conv_layer_01(): | 
		
	
		
			
			|  |  |  | d = _some_data_with_interactions() | 
		
	
		
			
			|  |  |  | in_layer = OneHotInputLayer(d) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d) | 
		
	
		
			
			|  |  |  | seq_1 = torch.nn.Sequential(in_layer, d_layer) | 
		
	
		
			
			|  |  |  | out_repr_1 = seq_1(None) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d) | 
		
	
		
			
			|  |  |  | seq_2 = torch.nn.Sequential(in_layer, conv_layer) | 
		
	
		
			
			|  |  |  | out_repr_2 = seq_2(None) |