IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Make InputLayer support variable dimensionality representations.

master
Stanislaw Adaszewski 4 years ago
parent
commit
c45e0fa9f1
2 changed files with 33 additions and 4 deletions
  1. +6
    -2
      src/decagon_pytorch/layer.py
  2. +27
    -2
      tests/decagon_pytorch/test_layer.py

+ 6
- 2
src/decagon_pytorch/layer.py View File

@@ -25,9 +25,13 @@ from .convolve import SparseMultiDGCA
class InputLayer(torch.nn.Module): class InputLayer(torch.nn.Module):
def __init__(self, data, dimensionality=32, **kwargs):
def __init__(self, data, dimensionality=None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.data = data self.data = data
dimensionality = dimensionality or \
list(map(lambda a: a.count, data.node_types))
if not isinstance(dimensionality, list):
dimensionality = [dimensionality,] * len(self.data.node_types)
self.dimensionality = dimensionality self.dimensionality = dimensionality
self.node_reps = None self.node_reps = None
self.build() self.build()
@@ -35,7 +39,7 @@ class InputLayer(torch.nn.Module):
def build(self): def build(self):
self.node_reps = [] self.node_reps = []
for i, nt in enumerate(self.data.node_types): for i, nt in enumerate(self.data.node_types):
reps = torch.rand(nt.count, self.dimensionality)
reps = torch.rand(nt.count, self.dimensionality[i])
reps = torch.nn.Parameter(reps) reps = torch.nn.Parameter(reps)
self.register_parameter('node_reps[%d]' % i, reps) self.register_parameter('node_reps[%d]' % i, reps)
self.node_reps.append(reps) self.node_reps.append(reps)


+ 27
- 2
tests/decagon_pytorch/test_layer.py View File

@@ -1,4 +1,5 @@
from decagon_pytorch.layer import InputLayer
from decagon_pytorch.layer import InputLayer, \
DecagonLayer
from decagon_pytorch.data import Data from decagon_pytorch.data import Data
import torch import torch
import pytest import pytest
@@ -16,11 +17,28 @@ def _some_data():
return d return d
def _some_data_with_interactions():
d = Data()
d.add_node_type('Gene', 1000)
d.add_node_type('Drug', 100)
d.add_relation_type('Target', 1, 0,
torch.rand((100, 1000), dtype=torch.float32).round())
d.add_relation_type('Interaction', 0, 0,
torch.rand((1000, 1000), dtype=torch.float32).round())
d.add_relation_type('Side Effect: Nausea', 1, 1,
torch.rand((100, 100), dtype=torch.float32).round())
d.add_relation_type('Side Effect: Infertility', 1, 1,
torch.rand((100, 100), dtype=torch.float32).round())
d.add_relation_type('Side Effect: Death', 1, 1,
torch.rand((100, 100), dtype=torch.float32).round())
return d
def test_input_layer_01(): def test_input_layer_01():
d = _some_data() d = _some_data()
for dimensionality in [32, 64, 128]: for dimensionality in [32, 64, 128]:
layer = InputLayer(d, dimensionality) layer = InputLayer(d, dimensionality)
assert layer.dimensionality == dimensionality
assert layer.dimensionality[0] == dimensionality
assert len(layer.node_reps) == 2 assert len(layer.node_reps) == 2
assert layer.node_reps[0].shape == (1000, dimensionality) assert layer.node_reps[0].shape == (1000, dimensionality)
assert layer.node_reps[1].shape == (100, dimensionality) assert layer.node_reps[1].shape == (100, dimensionality)
@@ -50,3 +68,10 @@ def test_input_layer_03():
# assert layer.device.type == 'cuda:0' # assert layer.device.type == 'cuda:0'
assert layer.node_reps[0].device == device assert layer.node_reps[0].device == device
assert layer.node_reps[1].device == device assert layer.node_reps[1].device == device
@pytest.mark.skip()
def test_decagon_layer_01():
d = _some_data_with_interactions()
in_layer = InputLayer(d)
d_layer = DecagonLayer(in_layer, output_dim=32)

Loading…
Cancel
Save