From 5464c5a4c1fe01c4e04b1fae6b402c54f0248924 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Tue, 19 May 2020 15:28:35 +0200
Subject: [PATCH] Add InputLayer.

---
 src/decagon_pytorch/layer.py        | 60 +++++++++++++++++++++++++++++
 tests/decagon_pytorch/test_data.py  |  1 -
 tests/decagon_pytorch/test_layer.py | 52 +++++++++++++++++++++++++
 3 files changed, 112 insertions(+), 1 deletion(-)
 create mode 100644 src/decagon_pytorch/layer.py
 create mode 100644 tests/decagon_pytorch/test_layer.py

diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py
new file mode 100644
index 0000000..3929062
--- /dev/null
+++ b/src/decagon_pytorch/layer.py
@@ -0,0 +1,60 @@
+#
+# This module implements a single layer of the Decagon
+# model. This is going to be already quite complex, as
+# we will be using all the graph convolutional building
+# blocks.
+#
+# h_{i}^(k+1) = ϕ(∑_r ∑_{j∈N{r}^{i}} c_{r}^{ij} * \
+#     W_{r}^(k) h_{j}^{k} + c_{r}^{i} h_{i}^(k))
+#
+# N{r}^{i} - set of neighbors of node i under relation r
+# W_{r}^(k) - relation-type specific weight matrix
+# h_{i}^(k) - hidden state of node i in layer k
+# h_{i}^(k)∈R^{d(k)} where d(k) is the dimensionality
+#     of the representation in k-th layer
+# ϕ - activation function
+# c_{r}^{ij} - normalization constants
+# c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
+# c_{r}^{i} - normalization constants
+# c_{r}^{i} = 1/|N_{r}^{i}|
+#
+
+
+import torch
+
+
+class InputLayer(torch.nn.Module):
+    def __init__(self, data, dimensionality=32, **kwargs):
+        super().__init__(**kwargs)
+        self.data = data
+        self.dimensionality = dimensionality
+        self.node_reps = None
+        self.build()
+
+    def build(self):
+        self.node_reps = []
+        for i, nt in enumerate(self.data.node_types):
+            reps = torch.rand(nt.count, self.dimensionality)
+            reps = torch.nn.Parameter(reps)
+            self.register_parameter('node_reps[%d]' % i, reps)
+            self.node_reps.append(reps)
+
+    def forward(self):
+        return self.node_reps
+
+    def __repr__(self):
+        s = ''
+        s += 'GNN input layer with dimensionality: %d\n' % self.dimensionality
+        s += '  # of node types: %d\n' % len(self.data.node_types)
+        for nt in self.data.node_types:
+            s += '    - %s (%d)\n' % (nt.name, nt.count)
+        return s.strip()
+
+
+class DecagonLayer(torch.nn.Module):
+    def __init__(self, data, **kwargs):
+        super().__init__(**kwargs)
+        self.data = data
+
+    def forward(self, previous_layer):
+        pass
diff --git a/tests/decagon_pytorch/test_data.py b/tests/decagon_pytorch/test_data.py
index fc6c111..51426ee 100644
--- a/tests/decagon_pytorch/test_data.py
+++ b/tests/decagon_pytorch/test_data.py
@@ -1,5 +1,4 @@
 from decagon_pytorch.data import Data
-from decagon_pytorch.decode import DEDICOMDecoder
 
 
 def test_data():
diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py
new file mode 100644
index 0000000..7e2115c
--- /dev/null
+++ b/tests/decagon_pytorch/test_layer.py
@@ -0,0 +1,52 @@
+from decagon_pytorch.layer import InputLayer
+from decagon_pytorch.data import Data
+import torch
+import pytest
+
+
+def _some_data():
+    d = Data()
+    d.add_node_type('Gene', 1000)
+    d.add_node_type('Drug', 100)
+    d.add_relation_type('Target', 1, 0, None)
+    d.add_relation_type('Interaction', 0, 0, None)
+    d.add_relation_type('Side Effect: Nausea', 1, 1, None)
+    d.add_relation_type('Side Effect: Infertility', 1, 1, None)
+    d.add_relation_type('Side Effect: Death', 1, 1, None)
+    return d
+
+
+def test_input_layer_01():
+    d = _some_data()
+    for dimensionality in [32, 64, 128]:
+        layer = InputLayer(d, dimensionality)
+        assert layer.dimensionality == dimensionality
+        assert len(layer.node_reps) == 2
+        assert layer.node_reps[0].shape == (1000, dimensionality)
+        assert layer.node_reps[1].shape == (100, dimensionality)
+        assert layer.data == d
+
+
+def test_input_layer_02():
+    d = _some_data()
+    layer = InputLayer(d, 32)
+    res = layer()
+    assert isinstance(res[0], torch.Tensor)
+    assert isinstance(res[1], torch.Tensor)
+    assert res[0].shape == (1000, 32)
+    assert res[1].shape == (100, 32)
+    assert torch.all(res[0] == layer.node_reps[0])
+    assert torch.all(res[1] == layer.node_reps[1])
+
+
+def test_input_layer_03():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('No CUDA devices on this host')
+    d = _some_data()
+    layer = InputLayer(d, 32)
+    device = torch.device('cuda:0')
+    layer = layer.to(device)
+    print(list(layer.parameters()))
+    # assert layer.device.type == 'cuda:0'
+    assert layer.node_reps[0].device == device
+    assert layer.node_reps[1].device == device