| 
				
				
					
				
				
				 | 
			
			 | 
			@@ -4,7 +4,9 @@ from decagon_pytorch.layer import InputLayer, \ | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			from decagon_pytorch.data import Data
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			import torch
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			import pytest
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			from decagon_pytorch.convolve import SparseDropoutGraphConvActivation
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			from decagon_pytorch.convolve import SparseDropoutGraphConvActivation, \
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    SparseMultiDGCA, \
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    DropoutGraphConvActivation
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def _some_data():
 | 
		
		
	
	
		
			
				| 
				
					
				
				
					
				
				
				 | 
			
			 | 
			@@ -136,9 +138,9 @@ def test_decagon_layer_03(): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert len(d_layer.next_layer_repr) == 2
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert len(d_layer.next_layer_repr[0]) == 2
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert len(d_layer.next_layer_repr[1]) == 4
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation),
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation),
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        d_layer.next_layer_repr[0]))
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation),
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation),
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        d_layer.next_layer_repr[1]))
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert all(map(lambda a: a[0].output_dim == 32,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        d_layer.next_layer_repr[0]))
 | 
		
		
	
	
		
			
				| 
				
				
				
					
				
				 | 
			
			 | 
			@@ -147,7 +149,38 @@ def test_decagon_layer_03(): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def test_decagon_layer_04():
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    d = _some_data_with_interactions()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    # check if it is equivalent to MultiDGCA, as it should be
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    d = Data()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    d.add_node_type('Dummy', 100)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    d.add_relation_type('Dummy Relation', 0, 0,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    in_layer = OneHotInputLayer(d)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    d_layer = DecagonLayer(d, in_layer, output_dim=32)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    _ = d_layer()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    multi_dgca = SparseMultiDGCA([10], 32,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        [r.adjacency_matrix for r in d.relation_types[0, 0]],
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        keep_prob=1., activation=lambda x: x)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    d_layer = DecagonLayer(d, in_layer, output_dim=32,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        keep_prob=1., rel_activation=lambda x: x,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        layer_activation=lambda x: x)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert isinstance(d_layer.next_layer_repr[0][0][0],
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        DropoutGraphConvActivation)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    weight = d_layer.next_layer_repr[0][0][0].graph_conv.weight
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert isinstance(weight, torch.Tensor)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert len(multi_dgca.sparse_dgca) == 1
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert isinstance(multi_dgca.sparse_dgca[0], SparseDropoutGraphConvActivation)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    multi_dgca.sparse_dgca[0].sparse_graph_conv.weight = weight
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    out_d_layer = d_layer()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    out_multi_dgca = multi_dgca(in_layer())
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert isinstance(out_d_layer, list)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert len(out_d_layer) == 1
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert torch.all(out_d_layer[0] == out_multi_dgca)
 |