From c45e0fa9f11e49260d846579717ed1acd9d18a4a Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Mon, 25 May 2020 17:46:03 +0200 Subject: [PATCH] Make InputLayer support variable dimensionality representations. --- src/decagon_pytorch/layer.py | 8 ++++++-- tests/decagon_pytorch/test_layer.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py index a75c005..f21c182 100644 --- a/src/decagon_pytorch/layer.py +++ b/src/decagon_pytorch/layer.py @@ -25,9 +25,13 @@ from .convolve import SparseMultiDGCA class InputLayer(torch.nn.Module): - def __init__(self, data, dimensionality=32, **kwargs): + def __init__(self, data, dimensionality=None, **kwargs): super().__init__(**kwargs) self.data = data + dimensionality = dimensionality or \ + list(map(lambda a: a.count, data.node_types)) + if not isinstance(dimensionality, list): + dimensionality = [dimensionality,] * len(self.data.node_types) self.dimensionality = dimensionality self.node_reps = None self.build() @@ -35,7 +39,7 @@ class InputLayer(torch.nn.Module): def build(self): self.node_reps = [] for i, nt in enumerate(self.data.node_types): - reps = torch.rand(nt.count, self.dimensionality) + reps = torch.rand(nt.count, self.dimensionality[i]) reps = torch.nn.Parameter(reps) self.register_parameter('node_reps[%d]' % i, reps) self.node_reps.append(reps) diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py index 7e2115c..c171e1c 100644 --- a/tests/decagon_pytorch/test_layer.py +++ b/tests/decagon_pytorch/test_layer.py @@ -1,4 +1,5 @@ -from decagon_pytorch.layer import InputLayer +from decagon_pytorch.layer import InputLayer, \ + DecagonLayer from decagon_pytorch.data import Data import torch import pytest @@ -16,11 +17,28 @@ def _some_data(): return d +def _some_data_with_interactions(): + d = Data() + d.add_node_type('Gene', 1000) + d.add_node_type('Drug', 100) + d.add_relation_type('Target', 1, 0, + torch.rand((100, 1000), dtype=torch.float32).round()) + d.add_relation_type('Interaction', 0, 0, + torch.rand((1000, 1000), dtype=torch.float32).round()) + d.add_relation_type('Side Effect: Nausea', 1, 1, + torch.rand((100, 100), dtype=torch.float32).round()) + d.add_relation_type('Side Effect: Infertility', 1, 1, + torch.rand((100, 100), dtype=torch.float32).round()) + d.add_relation_type('Side Effect: Death', 1, 1, + torch.rand((100, 100), dtype=torch.float32).round()) + return d + + def test_input_layer_01(): d = _some_data() for dimensionality in [32, 64, 128]: layer = InputLayer(d, dimensionality) - assert layer.dimensionality == dimensionality + assert layer.dimensionality[0] == dimensionality assert len(layer.node_reps) == 2 assert layer.node_reps[0].shape == (1000, dimensionality) assert layer.node_reps[1].shape == (100, dimensionality) @@ -50,3 +68,10 @@ def test_input_layer_03(): # assert layer.device.type == 'cuda:0' assert layer.node_reps[0].device == device assert layer.node_reps[1].device == device + + +@pytest.mark.skip() +def test_decagon_layer_01(): + d = _some_data_with_interactions() + in_layer = InputLayer(d) + d_layer = DecagonLayer(in_layer, output_dim=32)