From a305dc70aa8118fe654c34b836f451dba99e86a2 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Sun, 7 Jun 2020 13:40:10 +0200 Subject: [PATCH] Add tests for convlayer. --- src/decagon_pytorch/convolve/universal.py | 2 +- src/icosagon/convlayer.py | 11 +- src/icosagon/input.py | 5 +- tests/icosagon/test_convlayer.py | 167 ++++++++++++++++++++++ 4 files changed, 179 insertions(+), 6 deletions(-) create mode 100644 tests/icosagon/test_convlayer.py diff --git a/src/decagon_pytorch/convolve/universal.py b/src/decagon_pytorch/convolve/universal.py index f39d1a8..b266448 100644 --- a/src/decagon_pytorch/convolve/universal.py +++ b/src/decagon_pytorch/convolve/universal.py @@ -73,7 +73,7 @@ class MultiDGCA(torch.nn.Module): raise ValueError('input_dim must have the same length as adjacency_matrices') self.dgca = [] for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): - self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) + self.dgca.append(DropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: if not isinstance(x, list): diff --git a/src/icosagon/convlayer.py b/src/icosagon/convlayer.py index aef27ea..88f15b8 100644 --- a/src/icosagon/convlayer.py +++ b/src/icosagon/convlayer.py @@ -30,8 +30,11 @@ class DecagonLayer(torch.nn.Module): if not isinstance(input_dim, list): raise ValueError('input_dim must be a list') + if not output_dim: + raise ValueError('output_dim must be specified') + if not isinstance(output_dim, list): - raise ValueError('output_dim must be a list') + output_dim = [output_dim] * len(data.node_types) if not isinstance(data, Data) and not isinstance(data, PreparedData): raise ValueError('data must be of type Data or PreparedData') @@ -87,8 +90,8 @@ class DecagonLayer(torch.nn.Module): for conv in convolutions.convolutions ] repr_ = sum(repr_) repr_ = torch.nn.functional.normalize(repr_, p=2, dim=1) - next_layer_repr[i].append(repr_) - next_layer_repr[i] = sum(next_layer_repr[i]) - next_layer_repr[i] = self.layer_activation(next_layer_repr[i]) + next_layer_repr[node_type_row].append(repr_) + next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) + next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) return next_layer_repr diff --git a/src/icosagon/input.py b/src/icosagon/input.py index c0b2672..4f7dd73 100644 --- a/src/icosagon/input.py +++ b/src/icosagon/input.py @@ -11,9 +11,12 @@ from .data import Data class InputLayer(torch.nn.Module): - def __init__(self, data: Data, output_dim: Union[int, List[int]] = None, **kwargs) -> None: + def __init__(self, data: Data, output_dim: Union[int, List[int]] = None, + **kwargs) -> None: + output_dim = output_dim or \ list(map(lambda a: a.count, data.node_types)) + if not isinstance(output_dim, list): output_dim = [output_dim,] * len(data.node_types) diff --git a/tests/icosagon/test_convlayer.py b/tests/icosagon/test_convlayer.py new file mode 100644 index 0000000..82b7f56 --- /dev/null +++ b/tests/icosagon/test_convlayer.py @@ -0,0 +1,167 @@ +from icosagon.input import InputLayer, \ + OneHotInputLayer +from icosagon.convlayer import DecagonLayer, \ + Convolutions +from icosagon.data import Data +import torch +import pytest +from icosagon.convolve import DropoutGraphConvActivation +from decagon_pytorch.convolve import MultiDGCA +import decagon_pytorch.convolve + + +def _some_data_with_interactions(): + d = Data() + d.add_node_type('Gene', 1000) + d.add_node_type('Drug', 100) + d.add_relation_type('Target', 1, 0, + torch.rand((100, 1000), dtype=torch.float32).round()) + d.add_relation_type('Interaction', 0, 0, + torch.rand((1000, 1000), dtype=torch.float32).round()) + d.add_relation_type('Side Effect: Nausea', 1, 1, + torch.rand((100, 100), dtype=torch.float32).round()) + d.add_relation_type('Side Effect: Infertility', 1, 1, + torch.rand((100, 100), dtype=torch.float32).round()) + d.add_relation_type('Side Effect: Death', 1, 1, + torch.rand((100, 100), dtype=torch.float32).round()) + return d + + +def test_decagon_layer_01(): + d = _some_data_with_interactions() + in_layer = InputLayer(d) + d_layer = DecagonLayer(in_layer.output_dim, 32, d) + seq = torch.nn.Sequential(in_layer, d_layer) + _ = seq(None) # dummy call + + +def test_decagon_layer_02(): + d = _some_data_with_interactions() + in_layer = OneHotInputLayer(d) + d_layer = DecagonLayer(in_layer.output_dim, 32, d) + seq = torch.nn.Sequential(in_layer, d_layer) + _ = seq(None) # dummy call + + +def test_decagon_layer_03(): + d = _some_data_with_interactions() + in_layer = OneHotInputLayer(d) + d_layer = DecagonLayer(in_layer.output_dim, 32, d) + + assert d_layer.input_dim == [ 1000, 100 ] + assert d_layer.output_dim == [ 32, 32 ] + assert d_layer.data == d + assert d_layer.keep_prob == 1. + assert d_layer.rel_activation(0.5) == 0.5 + x = torch.tensor([-1, 0, 0.5, 1]) + assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all() + + assert not d_layer.is_sparse + assert len(d_layer.next_layer_repr) == 2 + + for i in range(2): + assert len(d_layer.next_layer_repr[i]) == 2 + assert isinstance(d_layer.next_layer_repr[i], list) + assert isinstance(d_layer.next_layer_repr[i][0], Convolutions) + assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int) + assert isinstance(d_layer.next_layer_repr[i][0].convolutions, list) + assert all([ + isinstance(dgca, DropoutGraphConvActivation) \ + for dgca in d_layer.next_layer_repr[i][0].convolutions + ]) + assert all([ + dgca.output_dim == 32 \ + for dgca in d_layer.next_layer_repr[i][0].convolutions + ]) + + +def test_decagon_layer_04(): + # check if it is equivalent to MultiDGCA, as it should be + + d = Data() + d.add_node_type('Dummy', 100) + d.add_relation_type('Dummy Relation', 0, 0, + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + + in_layer = OneHotInputLayer(d) + + multi_dgca = MultiDGCA([10], 32, + [r.adjacency_matrix for r in d.relation_types[0][0]], + keep_prob=1., activation=lambda x: x) + + d_layer = DecagonLayer(in_layer.output_dim, 32, d, + keep_prob=1., rel_activation=lambda x: x, + layer_activation=lambda x: x) + + assert isinstance(d_layer.next_layer_repr[0][0].convolutions[0], + DropoutGraphConvActivation) + + weight = d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight + assert isinstance(weight, torch.Tensor) + + assert len(multi_dgca.dgca) == 1 + assert isinstance(multi_dgca.dgca[0], + decagon_pytorch.convolve.DropoutGraphConvActivation) + + multi_dgca.dgca[0].graph_conv.weight = weight + + seq = torch.nn.Sequential(in_layer, d_layer) + out_d_layer = seq(None) + out_multi_dgca = multi_dgca(in_layer(None)) + + assert isinstance(out_d_layer, list) + assert len(out_d_layer) == 1 + + assert torch.all(out_d_layer[0] == out_multi_dgca) + + +def test_decagon_layer_05(): + # check if it is equivalent to MultiDGCA, as it should be + # this time for two relations, same edge type + + d = Data() + d.add_node_type('Dummy', 100) + d.add_relation_type('Dummy Relation 1', 0, 0, + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + d.add_relation_type('Dummy Relation 2', 0, 0, + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + + in_layer = OneHotInputLayer(d) + + multi_dgca = MultiDGCA([100, 100], 32, + [r.adjacency_matrix for r in d.relation_types[0][0]], + keep_prob=1., activation=lambda x: x) + + d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, + keep_prob=1., rel_activation=lambda x: x, + layer_activation=lambda x: x) + + assert all([ + isinstance(dgca, DropoutGraphConvActivation) \ + for dgca in d_layer.next_layer_repr[0][0].convolutions + ]) + + weight = [ dgca.graph_conv.weight \ + for dgca in d_layer.next_layer_repr[0][0].convolutions ] + assert all([ + isinstance(w, torch.Tensor) \ + for w in weight + ]) + + assert len(multi_dgca.dgca) == 2 + for i in range(2): + assert isinstance(multi_dgca.dgca[i], + decagon_pytorch.convolve.DropoutGraphConvActivation) + + for i in range(2): + multi_dgca.dgca[i].graph_conv.weight = weight[i] + + seq = torch.nn.Sequential(in_layer, d_layer) + out_d_layer = seq(None) + x = in_layer(None) + out_multi_dgca = multi_dgca([ x[0], x[0] ]) + + assert isinstance(out_d_layer, list) + assert len(out_d_layer) == 1 + + assert torch.all(out_d_layer[0] == out_multi_dgca)