From c37e4dc01e0623a3a6f6dd1490335a0edd3e4624 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Fri, 29 May 2020 15:23:58 +0200
Subject: [PATCH] Ok the very first decoding seems to work end-to-end.

---
 src/decagon_pytorch/__init__.py     |   1 +
 src/decagon_pytorch/layer/decode.py |  60 +++++++
 tests/decagon_pytorch/test_layer.py | 241 ----------------------------
 3 files changed, 61 insertions(+), 241 deletions(-)
 delete mode 100644 tests/decagon_pytorch/test_layer.py

diff --git a/src/decagon_pytorch/__init__.py b/src/decagon_pytorch/__init__.py
index f628a28..bc82355 100644
--- a/src/decagon_pytorch/__init__.py
+++ b/src/decagon_pytorch/__init__.py
@@ -1,3 +1,4 @@
 from .weights import *
 from .convolve import *
 from .model import *
+from .layer import *
diff --git a/src/decagon_pytorch/layer/decode.py b/src/decagon_pytorch/layer/decode.py
index e69de29..d0d548b 100644
--- a/src/decagon_pytorch/layer/decode.py
+++ b/src/decagon_pytorch/layer/decode.py
@@ -0,0 +1,60 @@
+from .layer import Layer
+import torch
+from ..data import Data
+from typing import Type, \
+    List, \
+    Callable, \
+    Union, \
+    Dict, \
+    Tuple
+from ..decode import DEDICOMDecoder
+
+
+class DecodeLayer(torch.nn.Module):
+    def __init__(self,
+        data: Data,
+        last_layer: Layer,
+        keep_prob: float = 1.,
+        activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
+        decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, **kwargs) -> None:
+
+        super().__init__(**kwargs)
+        self.data = data
+        self.last_layer = last_layer
+        self.keep_prob = keep_prob
+        self.activation = activation
+        assert all([ a == last_layer.output_dim[0] \
+            for a in last_layer.output_dim ])
+        self.input_dim = last_layer.output_dim[0]
+        self.output_dim = 1
+        self.decoder_class = decoder_class
+        self.decoders = None
+        self.build()
+
+    def build(self) -> None:
+        self.decoders = {}
+        for (node_type_row, node_type_col), rels in self.data.relation_types.items():
+            key = (node_type_row, node_type_col)
+            if isinstance(self.decoder_class, dict):
+                if key in self.decoder_class:
+                    decoder_class = self.decoder_class[key]
+                else:
+                    raise KeyError('Decoder not specified for edge type: %d -- %d' % key)
+            else:
+                decoder_class = self.decoder_class
+
+            self.decoders[key] = decoder_class(self.input_dim,
+                num_relation_types = len(rels),
+                drop_prob = 1. - self.keep_prob,
+                activation = self.activation)
+
+
+    def forward(self, last_layer_repr: List[torch.Tensor]):
+        res = {}
+        for (node_type_row, node_type_col), rel in self.data.relation_types.items():
+            key = (node_type_row, node_type_col)
+            inputs_row = last_layer_repr[node_type_row]
+            inputs_col = last_layer_repr[node_type_col]
+            pred_adj_matrices = self.decoders[key](inputs_row, inputs_col)
+            res[node_type_row, node_type_col] = pred_adj_matrices
+        return res
diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py
deleted file mode 100644
index 0b6207f..0000000
--- a/tests/decagon_pytorch/test_layer.py
+++ /dev/null
@@ -1,241 +0,0 @@
-from decagon_pytorch.layer import InputLayer, \
-    OneHotInputLayer, \
-    DecagonLayer
-from decagon_pytorch.data import Data
-import torch
-import pytest
-from decagon_pytorch.convolve import SparseDropoutGraphConvActivation, \
-    SparseMultiDGCA, \
-    DropoutGraphConvActivation
-
-
-def _some_data():
-    d = Data()
-    d.add_node_type('Gene', 1000)
-    d.add_node_type('Drug', 100)
-    d.add_relation_type('Target', 1, 0, None)
-    d.add_relation_type('Interaction', 0, 0, None)
-    d.add_relation_type('Side Effect: Nausea', 1, 1, None)
-    d.add_relation_type('Side Effect: Infertility', 1, 1, None)
-    d.add_relation_type('Side Effect: Death', 1, 1, None)
-    return d
-
-
-def _some_data_with_interactions():
-    d = Data()
-    d.add_node_type('Gene', 1000)
-    d.add_node_type('Drug', 100)
-    d.add_relation_type('Target', 1, 0,
-        torch.rand((100, 1000), dtype=torch.float32).round())
-    d.add_relation_type('Interaction', 0, 0,
-        torch.rand((1000, 1000), dtype=torch.float32).round())
-    d.add_relation_type('Side Effect: Nausea', 1, 1,
-        torch.rand((100, 100), dtype=torch.float32).round())
-    d.add_relation_type('Side Effect: Infertility', 1, 1,
-        torch.rand((100, 100), dtype=torch.float32).round())
-    d.add_relation_type('Side Effect: Death', 1, 1,
-        torch.rand((100, 100), dtype=torch.float32).round())
-    return d
-
-
-def test_input_layer_01():
-    d = _some_data()
-    for output_dim in [32, 64, 128]:
-        layer = InputLayer(d, output_dim)
-        assert layer.output_dim[0] == output_dim
-        assert len(layer.node_reps) == 2
-        assert layer.node_reps[0].shape == (1000, output_dim)
-        assert layer.node_reps[1].shape == (100, output_dim)
-        assert layer.data == d
-
-
-def test_input_layer_02():
-    d = _some_data()
-    layer = InputLayer(d, 32)
-    res = layer()
-    assert isinstance(res[0], torch.Tensor)
-    assert isinstance(res[1], torch.Tensor)
-    assert res[0].shape == (1000, 32)
-    assert res[1].shape == (100, 32)
-    assert torch.all(res[0] == layer.node_reps[0])
-    assert torch.all(res[1] == layer.node_reps[1])
-
-
-def test_input_layer_03():
-    if torch.cuda.device_count() == 0:
-        pytest.skip('No CUDA devices on this host')
-    d = _some_data()
-    layer = InputLayer(d, 32)
-    device = torch.device('cuda:0')
-    layer = layer.to(device)
-    print(list(layer.parameters()))
-    # assert layer.device.type == 'cuda:0'
-    assert layer.node_reps[0].device == device
-    assert layer.node_reps[1].device == device
-
-
-def test_one_hot_input_layer_01():
-    d = _some_data()
-    layer = OneHotInputLayer(d)
-    assert layer.output_dim == [1000, 100]
-    assert len(layer.node_reps) == 2
-    assert layer.node_reps[0].shape == (1000, 1000)
-    assert layer.node_reps[1].shape == (100, 100)
-    assert layer.data == d
-    assert layer.is_sparse
-
-
-def test_one_hot_input_layer_02():
-    d = _some_data()
-    layer = OneHotInputLayer(d)
-    res = layer()
-    assert isinstance(res[0], torch.Tensor)
-    assert isinstance(res[1], torch.Tensor)
-    assert res[0].shape == (1000, 1000)
-    assert res[1].shape == (100, 100)
-    assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
-    assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
-
-
-def test_one_hot_input_layer_03():
-    if torch.cuda.device_count() == 0:
-        pytest.skip('No CUDA devices on this host')
-    d = _some_data()
-    layer = OneHotInputLayer(d)
-    device = torch.device('cuda:0')
-    layer = layer.to(device)
-    print(list(layer.parameters()))
-    # assert layer.device.type == 'cuda:0'
-    assert layer.node_reps[0].device == device
-    assert layer.node_reps[1].device == device
-
-
-def test_decagon_layer_01():
-    d = _some_data_with_interactions()
-    in_layer = InputLayer(d)
-    d_layer = DecagonLayer(d, in_layer, output_dim=32)
-
-
-def test_decagon_layer_02():
-    d = _some_data_with_interactions()
-    in_layer = OneHotInputLayer(d)
-    d_layer = DecagonLayer(d, in_layer, output_dim=32)
-    _ = d_layer() # dummy call
-
-
-def test_decagon_layer_03():
-    d = _some_data_with_interactions()
-    in_layer = OneHotInputLayer(d)
-    d_layer = DecagonLayer(d, in_layer, output_dim=32)
-    assert d_layer.data == d
-    assert d_layer.previous_layer == in_layer
-    assert d_layer.input_dim == [ 1000, 100 ]
-    assert not d_layer.is_sparse
-    assert d_layer.keep_prob == 1.
-    assert d_layer.rel_activation(0.5) == 0.5
-    x = torch.tensor([-1, 0, 0.5, 1])
-    assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
-    assert len(d_layer.next_layer_repr) == 2
-
-    for i in range(2):
-        assert len(d_layer.next_layer_repr[i]) == 2
-        assert isinstance(d_layer.next_layer_repr[i], list)
-        assert isinstance(d_layer.next_layer_repr[i][0], tuple)
-        assert isinstance(d_layer.next_layer_repr[i][0][0], list)
-        assert isinstance(d_layer.next_layer_repr[i][0][1], int)
-        assert all([
-            isinstance(dgca, DropoutGraphConvActivation) \
-                for dgca in d_layer.next_layer_repr[i][0][0]
-        ])
-        assert all([
-            dgca.output_dim == 32 \
-                for dgca in d_layer.next_layer_repr[i][0][0]
-        ])
-
-
-def test_decagon_layer_04():
-    # check if it is equivalent to MultiDGCA, as it should be
-
-    d = Data()
-    d.add_node_type('Dummy', 100)
-    d.add_relation_type('Dummy Relation', 0, 0,
-        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
-
-    in_layer = OneHotInputLayer(d)
-
-    multi_dgca = SparseMultiDGCA([10], 32,
-        [r.adjacency_matrix for r in d.relation_types[0, 0]],
-        keep_prob=1., activation=lambda x: x)
-
-    d_layer = DecagonLayer(d, in_layer, output_dim=32,
-        keep_prob=1., rel_activation=lambda x: x,
-        layer_activation=lambda x: x)
-
-    assert isinstance(d_layer.next_layer_repr[0][0][0][0],
-        DropoutGraphConvActivation)
-
-    weight = d_layer.next_layer_repr[0][0][0][0].graph_conv.weight
-    assert isinstance(weight, torch.Tensor)
-
-    assert len(multi_dgca.sparse_dgca) == 1
-    assert isinstance(multi_dgca.sparse_dgca[0], SparseDropoutGraphConvActivation)
-
-    multi_dgca.sparse_dgca[0].sparse_graph_conv.weight = weight
-
-    out_d_layer = d_layer()
-    out_multi_dgca = multi_dgca(in_layer())
-
-    assert isinstance(out_d_layer, list)
-    assert len(out_d_layer) == 1
-
-    assert torch.all(out_d_layer[0] == out_multi_dgca)
-
-
-def test_decagon_layer_05():
-    # check if it is equivalent to MultiDGCA, as it should be
-    # this time for two relations, same edge type
-
-    d = Data()
-    d.add_node_type('Dummy', 100)
-    d.add_relation_type('Dummy Relation 1', 0, 0,
-        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
-    d.add_relation_type('Dummy Relation 2', 0, 0,
-        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
-
-    in_layer = OneHotInputLayer(d)
-
-    multi_dgca = SparseMultiDGCA([100, 100], 32,
-        [r.adjacency_matrix for r in d.relation_types[0, 0]],
-        keep_prob=1., activation=lambda x: x)
-
-    d_layer = DecagonLayer(d, in_layer, output_dim=32,
-        keep_prob=1., rel_activation=lambda x: x,
-        layer_activation=lambda x: x)
-
-    assert all([
-        isinstance(dgca, DropoutGraphConvActivation) \
-            for dgca in d_layer.next_layer_repr[0][0][0]
-    ])
-
-    weight = [ dgca.graph_conv.weight \
-        for dgca in d_layer.next_layer_repr[0][0][0] ]
-    assert all([
-        isinstance(w, torch.Tensor) \
-            for w in weight
-    ])
-
-    assert len(multi_dgca.sparse_dgca) == 2
-    for i in range(2):
-        assert isinstance(multi_dgca.sparse_dgca[i], SparseDropoutGraphConvActivation)
-
-    for i in range(2):
-        multi_dgca.sparse_dgca[i].sparse_graph_conv.weight = weight[i]
-
-    out_d_layer = d_layer()
-    x = in_layer()
-    out_multi_dgca = multi_dgca([ x[0], x[0] ])
-
-    assert isinstance(out_d_layer, list)
-    assert len(out_d_layer) == 1
-
-    assert torch.all(out_d_layer[0] == out_multi_dgca)
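
Usage note (not part of the patch itself): the sketch below shows how the new DecodeLayer might be wired up end-to-end. It assumes the constructor signatures visible in this patch and in the removed tests; Data, OneHotInputLayer, DecagonLayer and DEDICOMDecoder come from existing modules of this repository, while the toy node counts and random adjacency matrices are made up for illustration.

    import torch
    from decagon_pytorch.data import Data
    from decagon_pytorch.layer import OneHotInputLayer, DecagonLayer, DecodeLayer
    from decagon_pytorch.decode import DEDICOMDecoder

    # Toy heterogeneous graph with two node types and two edge types.
    d = Data()
    d.add_node_type('Gene', 1000)
    d.add_node_type('Drug', 100)
    d.add_relation_type('Target', 1, 0,
        torch.rand((100, 1000), dtype=torch.float32).round().to_sparse())
    d.add_relation_type('Interaction', 0, 0,
        torch.rand((1000, 1000), dtype=torch.float32).round().to_sparse())

    # One-hot inputs followed by a single graph convolution layer,
    # mirroring the setup used in the removed tests.
    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(d, in_layer, output_dim=32)

    # The new DecodeLayer builds one decoder per (row, col) edge type;
    # DEDICOMDecoder is the default decoder_class.
    dec_layer = DecodeLayer(d, last_layer=d_layer, keep_prob=1.,
        activation=torch.sigmoid, decoder_class=DEDICOMDecoder)

    last_layer_repr = d_layer()          # list of per-node-type embeddings
    pred = dec_layer(last_layer_repr)    # dict keyed by (node_type_row, node_type_col)
    # e.g. pred[1, 0] holds the predicted adjacency matrices for the
    # 'Target' edge type, one matrix per relation of that type.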
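
For reference, the default decoder_class, DEDICOMDecoder, is named after the DEDICOM factorization used in the original Decagon paper, where the score for relation r between nodes i and j is sigma(z_i^T D_r R D_r z_j), with a shared interaction matrix R and a per-relation diagonal D_r. The snippet below only illustrates that scoring rule on made-up tensors; the actual implementation in decagon_pytorch.decode may differ in details such as dropout and how multiple relation types are handled.

    import torch

    def dedicom_scores(inputs_row, inputs_col, local_variation, global_interaction):
        # inputs_row: (n_row, d) embeddings, inputs_col: (n_col, d) embeddings
        # local_variation: (d,) diagonal entries D_r for one relation type
        # global_interaction: (d, d) shared matrix R
        D = torch.diag(local_variation)
        product = D @ global_interaction @ D
        return torch.sigmoid(inputs_row @ product @ inputs_col.T)

    # Example: 32-dimensional embeddings for 100 drugs, scored against themselves.
    z = torch.rand(100, 32)
    scores = dedicom_scores(z, z, torch.rand(32), torch.rand(32, 32))
    print(scores.shape)  # torch.Size([100, 100])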