@@ -0,0 +1,60 @@ | |||||
# | |||||
# This module implements a single layer of the Decagon | |||||
# model. This is going to be already quite complex, as | |||||
# we will be using all the graph convolutional building | |||||
# blocks. | |||||
# | |||||
# h_{i}^(k+1) = ϕ(∑_r ∑_{j∈N{r}^{i}} c_{r}^{ij} * \ | |||||
# W_{r}^(k) h_{j}^{k} + c_{r}^{i} h_{i}^(k)) | |||||
# | |||||
# N{r}^{i} - set of neighbors of node i under relation r | |||||
# W_{r}^(k) - relation-type specific weight matrix | |||||
# h_{i}^(k) - hidden state of node i in layer k | |||||
# h_{i}^(k)∈R^{d(k)} where d(k) is the dimensionality | |||||
# of the representation in k-th layer | |||||
# ϕ - activation function | |||||
# c_{r}^{ij} - normalization constants | |||||
# c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) | |||||
# c_{r}^{i} - normalization constants | |||||
# c_{r}^{i} = 1/|N_{r}^{i}| | |||||
# | |||||
import torch | |||||
class InputLayer(torch.nn.Module): | |||||
def __init__(self, data, dimensionality=32, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.data = data | |||||
self.dimensionality = dimensionality | |||||
self.node_reps = None | |||||
self.build() | |||||
def build(self): | |||||
self.node_reps = [] | |||||
for i, nt in enumerate(self.data.node_types): | |||||
reps = torch.rand(nt.count, self.dimensionality) | |||||
reps = torch.nn.Parameter(reps) | |||||
self.register_parameter('node_reps[%d]' % i, reps) | |||||
self.node_reps.append(reps) | |||||
def forward(self): | |||||
return self.node_reps | |||||
def __repr__(self): | |||||
s = '' | |||||
s += 'GNN input layer with dimensionality: %d\n' % self.dimensionality | |||||
s += ' # of node types: %d\n' % len(self.data.node_types) | |||||
for nt in self.data.node_types: | |||||
s += ' - %s (%d)\n' % (nt.name, nt.count) | |||||
return s.strip() | |||||
class DecagonLayer(torch.nn.Module): | |||||
def __init__(self, data, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.data = data | |||||
def __call__(self, previous_layer): | |||||
pass |
@@ -1,5 +1,4 @@ | |||||
from decagon_pytorch.data import Data | from decagon_pytorch.data import Data | ||||
from decagon_pytorch.decode import DEDICOMDecoder | |||||
def test_data(): | def test_data(): | ||||
@@ -0,0 +1,52 @@ | |||||
from decagon_pytorch.layer import InputLayer | |||||
from decagon_pytorch.data import Data | |||||
import torch | |||||
import pytest | |||||
def _some_data(): | |||||
d = Data() | |||||
d.add_node_type('Gene', 1000) | |||||
d.add_node_type('Drug', 100) | |||||
d.add_relation_type('Target', 1, 0, None) | |||||
d.add_relation_type('Interaction', 0, 0, None) | |||||
d.add_relation_type('Side Effect: Nausea', 1, 1, None) | |||||
d.add_relation_type('Side Effect: Infertility', 1, 1, None) | |||||
d.add_relation_type('Side Effect: Death', 1, 1, None) | |||||
return d | |||||
def test_input_layer_01(): | |||||
d = _some_data() | |||||
for dimensionality in [32, 64, 128]: | |||||
layer = InputLayer(d, dimensionality) | |||||
assert layer.dimensionality == dimensionality | |||||
assert len(layer.node_reps) == 2 | |||||
assert layer.node_reps[0].shape == (1000, dimensionality) | |||||
assert layer.node_reps[1].shape == (100, dimensionality) | |||||
assert layer.data == d | |||||
def test_input_layer_02(): | |||||
d = _some_data() | |||||
layer = InputLayer(d, 32) | |||||
res = layer() | |||||
assert isinstance(res[0], torch.Tensor) | |||||
assert isinstance(res[1], torch.Tensor) | |||||
assert res[0].shape == (1000, 32) | |||||
assert res[1].shape == (100, 32) | |||||
assert torch.all(res[0] == layer.node_reps[0]) | |||||
assert torch.all(res[1] == layer.node_reps[1]) | |||||
def test_input_layer_03(): | |||||
if torch.cuda.device_count() == 0: | |||||
pytest.skip('No CUDA devices on this host') | |||||
d = _some_data() | |||||
layer = InputLayer(d, 32) | |||||
device = torch.device('cuda:0') | |||||
layer = layer.to(device) | |||||
print(list(layer.parameters())) | |||||
# assert layer.device.type == 'cuda:0' | |||||
assert layer.node_reps[0].device == device | |||||
assert layer.node_reps[1].device == device |