@@ -0,0 +1,60 @@ | |||
# | |||
# This module implements a single layer of the Decagon | |||
# model. This is going to be already quite complex, as | |||
# we will be using all the graph convolutional building | |||
# blocks. | |||
# | |||
# h_{i}^(k+1) = ϕ(∑_r ∑_{j∈N{r}^{i}} c_{r}^{ij} * \ | |||
# W_{r}^(k) h_{j}^{k} + c_{r}^{i} h_{i}^(k)) | |||
# | |||
# N{r}^{i} - set of neighbors of node i under relation r | |||
# W_{r}^(k) - relation-type specific weight matrix | |||
# h_{i}^(k) - hidden state of node i in layer k | |||
# h_{i}^(k)∈R^{d(k)} where d(k) is the dimensionality | |||
# of the representation in k-th layer | |||
# ϕ - activation function | |||
# c_{r}^{ij} - normalization constants | |||
# c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) | |||
# c_{r}^{i} - normalization constants | |||
# c_{r}^{i} = 1/|N_{r}^{i}| | |||
# | |||
import torch | |||
class InputLayer(torch.nn.Module): | |||
def __init__(self, data, dimensionality=32, **kwargs): | |||
super().__init__(**kwargs) | |||
self.data = data | |||
self.dimensionality = dimensionality | |||
self.node_reps = None | |||
self.build() | |||
def build(self): | |||
self.node_reps = [] | |||
for i, nt in enumerate(self.data.node_types): | |||
reps = torch.rand(nt.count, self.dimensionality) | |||
reps = torch.nn.Parameter(reps) | |||
self.register_parameter('node_reps[%d]' % i, reps) | |||
self.node_reps.append(reps) | |||
def forward(self): | |||
return self.node_reps | |||
def __repr__(self): | |||
s = '' | |||
s += 'GNN input layer with dimensionality: %d\n' % self.dimensionality | |||
s += ' # of node types: %d\n' % len(self.data.node_types) | |||
for nt in self.data.node_types: | |||
s += ' - %s (%d)\n' % (nt.name, nt.count) | |||
return s.strip() | |||
class DecagonLayer(torch.nn.Module): | |||
def __init__(self, data, **kwargs): | |||
super().__init__(**kwargs) | |||
self.data = data | |||
def __call__(self, previous_layer): | |||
pass |
@@ -1,5 +1,4 @@ | |||
from decagon_pytorch.data import Data | |||
from decagon_pytorch.decode import DEDICOMDecoder | |||
def test_data(): | |||
@@ -0,0 +1,52 @@ | |||
from decagon_pytorch.layer import InputLayer | |||
from decagon_pytorch.data import Data | |||
import torch | |||
import pytest | |||
def _some_data(): | |||
d = Data() | |||
d.add_node_type('Gene', 1000) | |||
d.add_node_type('Drug', 100) | |||
d.add_relation_type('Target', 1, 0, None) | |||
d.add_relation_type('Interaction', 0, 0, None) | |||
d.add_relation_type('Side Effect: Nausea', 1, 1, None) | |||
d.add_relation_type('Side Effect: Infertility', 1, 1, None) | |||
d.add_relation_type('Side Effect: Death', 1, 1, None) | |||
return d | |||
def test_input_layer_01(): | |||
d = _some_data() | |||
for dimensionality in [32, 64, 128]: | |||
layer = InputLayer(d, dimensionality) | |||
assert layer.dimensionality == dimensionality | |||
assert len(layer.node_reps) == 2 | |||
assert layer.node_reps[0].shape == (1000, dimensionality) | |||
assert layer.node_reps[1].shape == (100, dimensionality) | |||
assert layer.data == d | |||
def test_input_layer_02(): | |||
d = _some_data() | |||
layer = InputLayer(d, 32) | |||
res = layer() | |||
assert isinstance(res[0], torch.Tensor) | |||
assert isinstance(res[1], torch.Tensor) | |||
assert res[0].shape == (1000, 32) | |||
assert res[1].shape == (100, 32) | |||
assert torch.all(res[0] == layer.node_reps[0]) | |||
assert torch.all(res[1] == layer.node_reps[1]) | |||
def test_input_layer_03(): | |||
if torch.cuda.device_count() == 0: | |||
pytest.skip('No CUDA devices on this host') | |||
d = _some_data() | |||
layer = InputLayer(d, 32) | |||
device = torch.device('cuda:0') | |||
layer = layer.to(device) | |||
print(list(layer.parameters())) | |||
# assert layer.device.type == 'cuda:0' | |||
assert layer.node_reps[0].device == device | |||
assert layer.node_reps[1].device == device |