diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py index c781e6b..3b116e7 100644 --- a/src/decagon_pytorch/convolve.py +++ b/src/decagon_pytorch/convolve.py @@ -44,6 +44,7 @@ class SparseMultiDGCA(torch.nn.Module): activation=torch.nn.functional.relu, **kwargs): super().__init__(**kwargs) + self.output_dim = output_dim self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] def forward(self, x): @@ -92,6 +93,7 @@ class MultiDGCA(torch.nn.Module): activation=torch.nn.functional.relu, **kwargs): super().__init__(**kwargs) + self.output_dim = output_dim self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] def forward(self, x): diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py index 74998d0..1e2c538 100644 --- a/tests/decagon_pytorch/test_convolve.py +++ b/tests/decagon_pytorch/test_convolve.py @@ -180,7 +180,9 @@ def test_graph_conv(): def setup_function(fun): - if fun == test_dropout_graph_conv_activation: + if fun == test_dropout_graph_conv_activation or \ + fun == test_multi_dgca: + print('Disabling dropout for testing...') setup_function.old_dropout = decagon_pytorch.convolve.dropout, \ decagon_pytorch.convolve.dropout_sparse @@ -189,7 +191,9 @@ def setup_function(fun): def teardown_function(fun): - if fun == test_dropout_graph_conv_activation: + print('Re-enabling dropout...') + if fun == test_dropout_graph_conv_activation or \ + fun == test_multi_dgca: decagon_pytorch.convolve.dropout, \ decagon_pytorch.convolve.dropout_sparse = \ setup_function.old_dropout @@ -213,3 +217,26 @@ def test_dropout_graph_conv_activation(): nonzero = (latent_dense != 0) & (latent_sparse != 0) assert np.all(latent_dense[nonzero] == latent_sparse[nonzero]) + + +def test_multi_dgca(): + keep_prob = .5 + + torch.random.manual_seed(0) + latent, adjacency_matrices = prepare_data() + + latent_sparse = torch.tensor(latent).to_sparse() + latent = torch.tensor(latent) + + adjacency_matrices_sparse = [ torch.tensor(a).to_sparse() for a in adjacency_matrices ] + adjacency_matrices = [ torch.tensor(a) for a in adjacency_matrices ] + + multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA(10, 10, adjacency_matrices_sparse, keep_prob=keep_prob) + multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob) + + # torch.random.manual_seed(0) + latent_sparse = multi_sparse(latent_sparse) + # torch.random.manual_seed(0) + latent = multi(latent) + + assert np.all(latent_sparse.detach().numpy() - latent.detach().numpy() < .000001)