diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py index 5e7d6b0..0b6207f 100644 --- a/tests/decagon_pytorch/test_layer.py +++ b/tests/decagon_pytorch/test_layer.py @@ -189,3 +189,53 @@ def test_decagon_layer_04(): assert len(out_d_layer) == 1 assert torch.all(out_d_layer[0] == out_multi_dgca) + + +def test_decagon_layer_05(): + # check if it is equivalent to MultiDGCA, as it should be + # this time for two relations, same edge type + + d = Data() + d.add_node_type('Dummy', 100) + d.add_relation_type('Dummy Relation 1', 0, 0, + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + d.add_relation_type('Dummy Relation 2', 0, 0, + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + + in_layer = OneHotInputLayer(d) + + multi_dgca = SparseMultiDGCA([100, 100], 32, + [r.adjacency_matrix for r in d.relation_types[0, 0]], + keep_prob=1., activation=lambda x: x) + + d_layer = DecagonLayer(d, in_layer, output_dim=32, + keep_prob=1., rel_activation=lambda x: x, + layer_activation=lambda x: x) + + assert all([ + isinstance(dgca, DropoutGraphConvActivation) \ + for dgca in d_layer.next_layer_repr[0][0][0] + ]) + + weight = [ dgca.graph_conv.weight \ + for dgca in d_layer.next_layer_repr[0][0][0] ] + assert all([ + isinstance(w, torch.Tensor) \ + for w in weight + ]) + + assert len(multi_dgca.sparse_dgca) == 2 + for i in range(2): + assert isinstance(multi_dgca.sparse_dgca[i], SparseDropoutGraphConvActivation) + + for i in range(2): + multi_dgca.sparse_dgca[i].sparse_graph_conv.weight = weight[i] + + out_d_layer = d_layer() + x = in_layer() + out_multi_dgca = multi_dgca([ x[0], x[0] ]) + + assert isinstance(out_d_layer, list) + assert len(out_d_layer) == 1 + + assert torch.all(out_d_layer[0] == out_multi_dgca)