From 02bbfc4958de5ad87fd34f1500502d06246ded7c Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 29 May 2020 10:38:08 +0200 Subject: [PATCH] Add test_decagon_layer_04(). --- src/decagon_pytorch/layer.py | 7 +++-- tests/decagon_pytorch/test_layer.py | 45 +++++++++++++++++++++++++---- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py index b1f5d70..69ca74c 100644 --- a/src/decagon_pytorch/layer.py +++ b/src/decagon_pytorch/layer.py @@ -21,7 +21,7 @@ import torch -from .convolve import SparseDropoutGraphConvActivation +from .convolve import DropoutGraphConvActivation from .data import Data from typing import List, \ Union, \ @@ -125,7 +125,7 @@ class DecagonLayer(Layer): for (nt_row, nt_col), relation_types in self.data.relation_types.items(): for rel in relation_types: - conv = SparseDropoutGraphConvActivation(self.input_dim[nt_col], + conv = DropoutGraphConvActivation(self.input_dim[nt_col], self.output_dim[nt_row], rel.adjacency_matrix, self.keep_prob, self.rel_activation) self.next_layer_repr[nt_row].append((conv, nt_col)) @@ -133,7 +133,7 @@ class DecagonLayer(Layer): if nt_row == nt_col: continue - conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row], + conv = DropoutGraphConvActivation(self.input_dim[nt_row], self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1), self.keep_prob, self.rel_activation) self.next_layer_repr[nt_col].append((conv, nt_row)) @@ -149,6 +149,7 @@ class DecagonLayer(Layer): self.next_layer_repr[i] ] next_layer_repr[i] = sum(next_layer_repr[i]) + next_layer_repr[i] = torch.nn.functional.normalize(next_layer_repr[i], p=2, dim=1) print('next_layer_repr:', next_layer_repr) # next_layer_repr = list(map(sum, next_layer_repr)) diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py index cfb5063..cc0e5b7 100644 --- a/tests/decagon_pytorch/test_layer.py +++ b/tests/decagon_pytorch/test_layer.py @@ -4,7 +4,9 @@ from decagon_pytorch.layer import InputLayer, \ from decagon_pytorch.data import Data import torch import pytest -from decagon_pytorch.convolve import SparseDropoutGraphConvActivation +from decagon_pytorch.convolve import SparseDropoutGraphConvActivation, \ + SparseMultiDGCA, \ + DropoutGraphConvActivation def _some_data(): @@ -136,9 +138,9 @@ def test_decagon_layer_03(): assert len(d_layer.next_layer_repr) == 2 assert len(d_layer.next_layer_repr[0]) == 2 assert len(d_layer.next_layer_repr[1]) == 4 - assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation), + assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation), d_layer.next_layer_repr[0])) - assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation), + assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation), d_layer.next_layer_repr[1])) assert all(map(lambda a: a[0].output_dim == 32, d_layer.next_layer_repr[0])) @@ -147,7 +149,38 @@ def test_decagon_layer_03(): def test_decagon_layer_04(): - d = _some_data_with_interactions() + # check if it is equivalent to MultiDGCA, as it should be + + d = Data() + d.add_node_type('Dummy', 100) + d.add_relation_type('Dummy Relation', 0, 0, + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + in_layer = OneHotInputLayer(d) - d_layer = DecagonLayer(d, in_layer, output_dim=32) - _ = d_layer() + + multi_dgca = SparseMultiDGCA([10], 32, + [r.adjacency_matrix for r in d.relation_types[0, 0]], + keep_prob=1., activation=lambda x: x) + + d_layer = DecagonLayer(d, in_layer, output_dim=32, + keep_prob=1., rel_activation=lambda x: x, + layer_activation=lambda x: x) + + assert isinstance(d_layer.next_layer_repr[0][0][0], + DropoutGraphConvActivation) + + weight = d_layer.next_layer_repr[0][0][0].graph_conv.weight + assert isinstance(weight, torch.Tensor) + + assert len(multi_dgca.sparse_dgca) == 1 + assert isinstance(multi_dgca.sparse_dgca[0], SparseDropoutGraphConvActivation) + + multi_dgca.sparse_dgca[0].sparse_graph_conv.weight = weight + + out_d_layer = d_layer() + out_multi_dgca = multi_dgca(in_layer()) + + assert isinstance(out_d_layer, list) + assert len(out_d_layer) == 1 + + assert torch.all(out_d_layer[0] == out_multi_dgca)