From d4dd1f29230a0b05fc27a04d806b487cac147ec8 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Sun, 7 Jun 2020 12:51:22 +0200
Subject: [PATCH] Add input to icosagon.

---
Reviewer notes (free-form text after the three-dash line; not part of the
commit message and not part of any hunk):

  * Purpose: renames src/icosagon/layer.py to convlayer.py, adds a
    placeholder declayer.py, adds input.py with InputLayer (dense random
    node representations) and OneHotInputLayer (sparse identity
    representations), and adds tests for both layers.
  * This copy was rebuilt from a whitespace-collapsed paste. Indentation
    and continuation-line alignment are reconstructed from Python syntax;
    the added-line counts of the three hunks (2, 76, 106) match the hunk
    headers. Runs of spaces inside string literals are reproduced as they
    appeared in the collapsed text and may have been wider originally.
  * NOTE(review): OneHotInputLayer.build() wraps the sparse identity in
    torch.nn.Parameter without requires_grad=False, so the one-hot
    matrices are reported by parameters() as trainable. Confirm whether
    they are meant to stay fixed.
  * NOTE(review): InputLayer.__init__ annotates output_dim as
    Union[int, List[int]] but defaults it to None, and "output_dim or ..."
    also replaces an explicit 0 with the per-node-type counts.
  * NOTE(review): declayer.py holds only commented-out imports; the first
    one names ".layer", which this same patch renames to ".convlayer".
  * NOTE(review): _some_data_with_interactions() in the test module is not
    called by any test in this patch.

 src/icosagon/{layer.py => convlayer.py} |   0
 src/icosagon/declayer.py                |   2 +
 src/icosagon/input.py                   |  76 +++++++++++++++++
 tests/icosagon/test_input.py            | 106 ++++++++++++++++++++++++
 4 files changed, 184 insertions(+)
 rename src/icosagon/{layer.py => convlayer.py} (100%)
 create mode 100644 src/icosagon/declayer.py
 create mode 100644 src/icosagon/input.py
 create mode 100644 tests/icosagon/test_input.py

diff --git a/src/icosagon/layer.py b/src/icosagon/convlayer.py
similarity index 100%
rename from src/icosagon/layer.py
rename to src/icosagon/convlayer.py
diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py
new file mode 100644
index 0000000..46bffca
--- /dev/null
+++ b/src/icosagon/declayer.py
@@ -0,0 +1,2 @@
+# from .layer import DecagonLayer
+# from .input import OneHotInputLayer
diff --git a/src/icosagon/input.py b/src/icosagon/input.py
new file mode 100644
index 0000000..c0b2672
--- /dev/null
+++ b/src/icosagon/input.py
@@ -0,0 +1,76 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from typing import Union, \
+    List
+from .data import Data
+
+
+class InputLayer(torch.nn.Module):
+    def __init__(self, data: Data, output_dim: Union[int, List[int]] = None, **kwargs) -> None:
+        output_dim = output_dim or \
+            list(map(lambda a: a.count, data.node_types))
+        if not isinstance(output_dim, list):
+            output_dim = [output_dim,] * len(data.node_types)
+
+        super().__init__(**kwargs)
+        self.output_dim = output_dim
+        self.data = data
+
+        self.is_sparse=False
+        self.node_reps = None
+        self.build()
+
+    def build(self) -> None:
+        self.node_reps = []
+        for i, nt in enumerate(self.data.node_types):
+            reps = torch.rand(nt.count, self.output_dim[i])
+            reps = torch.nn.Parameter(reps)
+            self.register_parameter('node_reps[%d]' % i, reps)
+            self.node_reps.append(reps)
+
+    def forward(self, x) -> List[torch.nn.Parameter]:
+        return self.node_reps
+
+    def __repr__(self) -> str:
+        s = ''
+        s += 'Icosagon input layer with output_dim: %s\n' % self.output_dim
+        s += ' # of node types: %d\n' % len(self.data.node_types)
+        for nt in self.data.node_types:
+            s += ' - %s (%d)\n' % (nt.name, nt.count)
+        return s.strip()
+
+
+class OneHotInputLayer(torch.nn.Module):
+    def __init__(self, data: Data, **kwargs) -> None:
+        output_dim = [ a.count for a in data.node_types ]
+        super().__init__(**kwargs)
+        self.output_dim = output_dim
+        self.data = data
+
+        self.is_sparse=True
+        self.node_reps = None
+        self.build()
+
+    def build(self) -> None:
+        self.node_reps = []
+        for i, nt in enumerate(self.data.node_types):
+            reps = torch.eye(nt.count).to_sparse()
+            reps = torch.nn.Parameter(reps)
+            self.register_parameter('node_reps[%d]' % i, reps)
+            self.node_reps.append(reps)
+
+    def forward(self, x) -> List[torch.nn.Parameter]:
+        return self.node_reps
+
+    def __repr__(self) -> str:
+        s = ''
+        s += 'One-hot Icosagon input layer\n'
+        s += ' # of node types: %d\n' % len(self.data.node_types)
+        for nt in self.data.node_types:
+            s += ' - %s (%d)\n' % (nt.name, nt.count)
+        return s.strip()
diff --git a/tests/icosagon/test_input.py b/tests/icosagon/test_input.py
new file mode 100644
index 0000000..3e73ce1
--- /dev/null
+++ b/tests/icosagon/test_input.py
@@ -0,0 +1,106 @@
+from icosagon.input import InputLayer, \
+    OneHotInputLayer
+from icosagon.data import Data
+import torch
+import pytest
+
+
+def _some_data():
+    d = Data()
+    d.add_node_type('Gene', 1000)
+    d.add_node_type('Drug', 100)
+    d.add_relation_type('Target', 1, 0, torch.rand(100, 1000))
+    d.add_relation_type('Interaction', 0, 0, torch.rand(1000, 1000))
+    d.add_relation_type('Side Effect: Nausea', 1, 1, torch.rand(100, 100))
+    d.add_relation_type('Side Effect: Infertility', 1, 1, torch.rand(100, 100))
+    d.add_relation_type('Side Effect: Death', 1, 1, torch.rand(100, 100))
+    return d
+
+
+def _some_data_with_interactions():
+    d = Data()
+    d.add_node_type('Gene', 1000)
+    d.add_node_type('Drug', 100)
+    d.add_relation_type('Target', 1, 0,
+        torch.rand((100, 1000), dtype=torch.float32).round())
+    d.add_relation_type('Interaction', 0, 0,
+        torch.rand((1000, 1000), dtype=torch.float32).round())
+    d.add_relation_type('Side Effect: Nausea', 1, 1,
+        torch.rand((100, 100), dtype=torch.float32).round())
+    d.add_relation_type('Side Effect: Infertility', 1, 1,
+        torch.rand((100, 100), dtype=torch.float32).round())
+    d.add_relation_type('Side Effect: Death', 1, 1,
+        torch.rand((100, 100), dtype=torch.float32).round())
+    return d
+
+
+def test_input_layer_01():
+    d = _some_data()
+    for output_dim in [32, 64, 128]:
+        layer = InputLayer(d, output_dim)
+        assert layer.output_dim[0] == output_dim
+        assert len(layer.node_reps) == 2
+        assert layer.node_reps[0].shape == (1000, output_dim)
+        assert layer.node_reps[1].shape == (100, output_dim)
+        assert layer.data == d
+
+
+def test_input_layer_02():
+    d = _some_data()
+    layer = InputLayer(d, 32)
+    res = layer(None)
+    assert isinstance(res[0], torch.Tensor)
+    assert isinstance(res[1], torch.Tensor)
+    assert res[0].shape == (1000, 32)
+    assert res[1].shape == (100, 32)
+    assert torch.all(res[0] == layer.node_reps[0])
+    assert torch.all(res[1] == layer.node_reps[1])
+
+
+def test_input_layer_03():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('No CUDA devices on this host')
+    d = _some_data()
+    layer = InputLayer(d, 32)
+    device = torch.device('cuda:0')
+    layer = layer.to(device)
+    print(list(layer.parameters()))
+    # assert layer.device.type == 'cuda:0'
+    assert layer.node_reps[0].device == device
+    assert layer.node_reps[1].device == device
+
+
+def test_one_hot_input_layer_01():
+    d = _some_data()
+    layer = OneHotInputLayer(d)
+    assert layer.output_dim == [1000, 100]
+    assert len(layer.node_reps) == 2
+    assert layer.node_reps[0].shape == (1000, 1000)
+    assert layer.node_reps[1].shape == (100, 100)
+    assert layer.data == d
+    assert layer.is_sparse
+
+
+def test_one_hot_input_layer_02():
+    d = _some_data()
+    layer = OneHotInputLayer(d)
+    res = layer(None)
+    assert isinstance(res[0], torch.Tensor)
+    assert isinstance(res[1], torch.Tensor)
+    assert res[0].shape == (1000, 1000)
+    assert res[1].shape == (100, 100)
+    assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
+    assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
+
+
+def test_one_hot_input_layer_03():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('No CUDA devices on this host')
+    d = _some_data()
+    layer = OneHotInputLayer(d)
+    device = torch.device('cuda:0')
+    layer = layer.to(device)
+    print(list(layer.parameters()))
+    # assert layer.device.type == 'cuda:0'
+    assert layer.node_reps[0].device == device
+    assert layer.node_reps[1].device == device