| @@ -189,3 +189,53 @@ def test_decagon_layer_04(): | |||||
| assert len(out_d_layer) == 1 | assert len(out_d_layer) == 1 | ||||
| assert torch.all(out_d_layer[0] == out_multi_dgca) | assert torch.all(out_d_layer[0] == out_multi_dgca) | ||||
| def test_decagon_layer_05(): | |||||
| # check if it is equivalent to MultiDGCA, as it should be | |||||
| # this time for two relations, same edge type | |||||
| d = Data() | |||||
| d.add_node_type('Dummy', 100) | |||||
| d.add_relation_type('Dummy Relation 1', 0, 0, | |||||
| torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | |||||
| d.add_relation_type('Dummy Relation 2', 0, 0, | |||||
| torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | |||||
| in_layer = OneHotInputLayer(d) | |||||
| multi_dgca = SparseMultiDGCA([100, 100], 32, | |||||
| [r.adjacency_matrix for r in d.relation_types[0, 0]], | |||||
| keep_prob=1., activation=lambda x: x) | |||||
| d_layer = DecagonLayer(d, in_layer, output_dim=32, | |||||
| keep_prob=1., rel_activation=lambda x: x, | |||||
| layer_activation=lambda x: x) | |||||
| assert all([ | |||||
| isinstance(dgca, DropoutGraphConvActivation) \ | |||||
| for dgca in d_layer.next_layer_repr[0][0][0] | |||||
| ]) | |||||
| weight = [ dgca.graph_conv.weight \ | |||||
| for dgca in d_layer.next_layer_repr[0][0][0] ] | |||||
| assert all([ | |||||
| isinstance(w, torch.Tensor) \ | |||||
| for w in weight | |||||
| ]) | |||||
| assert len(multi_dgca.sparse_dgca) == 2 | |||||
| for i in range(2): | |||||
| assert isinstance(multi_dgca.sparse_dgca[i], SparseDropoutGraphConvActivation) | |||||
| for i in range(2): | |||||
| multi_dgca.sparse_dgca[i].sparse_graph_conv.weight = weight[i] | |||||
| out_d_layer = d_layer() | |||||
| x = in_layer() | |||||
| out_multi_dgca = multi_dgca([ x[0], x[0] ]) | |||||
| assert isinstance(out_d_layer, list) | |||||
| assert len(out_d_layer) == 1 | |||||
| assert torch.all(out_d_layer[0] == out_multi_dgca) | |||||