| @@ -0,0 +1,60 @@ | |||||
| # | |||||
| # This module implements a single layer of the Decagon | |||||
| # model. This is going to be already quite complex, as | |||||
| # we will be using all the graph convolutional building | |||||
| # blocks. | |||||
| # | |||||
# h_{i}^(k+1) = ϕ(∑_r ∑_{j∈N_{r}^{i}} c_{r}^{ij} * \
#              W_{r}^(k) h_{j}^(k) + c_{r}^{i} h_{i}^(k))
#
# N_{r}^{i} - set of neighbors of node i under relation r
| # W_{r}^(k) - relation-type specific weight matrix | |||||
| # h_{i}^(k) - hidden state of node i in layer k | |||||
| # h_{i}^(k)∈R^{d(k)} where d(k) is the dimensionality | |||||
| # of the representation in k-th layer | |||||
| # ϕ - activation function | |||||
| # c_{r}^{ij} - normalization constants | |||||
| # c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) | |||||
| # c_{r}^{i} - normalization constants | |||||
| # c_{r}^{i} = 1/|N_{r}^{i}| | |||||
| # | |||||
| import torch | |||||
class InputLayer(torch.nn.Module):
    """Input layer of the Decagon model.

    Creates one trainable embedding matrix per node type. Matrix i has
    shape ``(data.node_types[i].count, dimensionality)`` and is
    initialized uniformly at random in [0, 1).
    """

    def __init__(self, data, dimensionality=32, **kwargs):
        """
        Parameters
        ----------
        data : object
            Must expose a ``node_types`` sequence whose entries have
            ``name`` and ``count`` attributes (e.g. decagon_pytorch Data).
        dimensionality : int
            Size d(0) of every node representation vector.
        """
        super().__init__(**kwargs)
        self.data = data
        self.dimensionality = dimensionality
        self.node_reps = None
        self.build()

    def build(self):
        # ParameterList registers every matrix with the module, so the
        # embeddings are picked up by .parameters(), .to(device) and
        # state_dict() without manual register_parameter() bookkeeping.
        self.node_reps = torch.nn.ParameterList(
            torch.nn.Parameter(torch.rand(nt.count, self.dimensionality))
            for nt in self.data.node_types
        )

    def forward(self):
        """Return the per-node-type embedding matrices."""
        return self.node_reps

    def __repr__(self):
        # Build the description as a list and join once instead of
        # repeated string concatenation.
        lines = ['GNN input layer with dimensionality: %d' % self.dimensionality]
        lines.append(' # of node types: %d' % len(self.data.node_types))
        lines.extend(' - %s (%d)' % (nt.name, nt.count)
                     for nt in self.data.node_types)
        return '\n'.join(lines)
class DecagonLayer(torch.nn.Module):
    """Single graph-convolutional Decagon layer (not yet implemented).

    Will eventually implement the propagation rule from the module
    header:
        h_i^(k+1) = phi(sum_r sum_{j in N_r^i} c_r^{ij} W_r^(k) h_j^(k)
                        + c_r^i h_i^(k))
    """

    def __init__(self, data, **kwargs):
        # `data` describes the graph (node types / relation types).
        super().__init__(**kwargs)
        self.data = data

    def forward(self, previous_layer):
        # Defining forward() instead of overriding __call__ keeps
        # nn.Module's hook machinery intact; `layer(prev)` still
        # returns None because the propagation rule is a stub.
        pass
| @@ -1,5 +1,4 @@ | |||||
| from decagon_pytorch.data import Data | from decagon_pytorch.data import Data | ||||
| from decagon_pytorch.decode import DEDICOMDecoder | |||||
| def test_data(): | def test_data(): | ||||
| @@ -0,0 +1,52 @@ | |||||
| from decagon_pytorch.layer import InputLayer | |||||
| from decagon_pytorch.data import Data | |||||
| import torch | |||||
| import pytest | |||||
def _some_data():
    """Build a small fixture: two node types and five relation types."""
    data = Data()
    data.add_node_type('Gene', 1000)
    data.add_node_type('Drug', 100)
    relations = [
        ('Target', 1, 0),
        ('Interaction', 0, 0),
        ('Side Effect: Nausea', 1, 1),
        ('Side Effect: Infertility', 1, 1),
        ('Side Effect: Death', 1, 1),
    ]
    for name, node_type_a, node_type_b in relations:
        data.add_relation_type(name, node_type_a, node_type_b, None)
    return data
def test_input_layer_01():
    """InputLayer must expose one rep matrix per node type, per dim."""
    data = _some_data()
    for dim in (32, 64, 128):
        layer = InputLayer(data, dim)
        assert layer.dimensionality == dim
        assert len(layer.node_reps) == 2
        expected = [(1000, dim), (100, dim)]
        for reps, shape in zip(layer.node_reps, expected):
            assert reps.shape == shape
        assert layer.data == data
def test_input_layer_02():
    """Calling the layer must return the stored representation tensors."""
    data = _some_data()
    layer = InputLayer(data, 32)
    res = layer()
    for idx, count in enumerate((1000, 100)):
        tensor = res[idx]
        assert isinstance(tensor, torch.Tensor)
        assert tensor.shape == (count, 32)
        assert torch.all(tensor == layer.node_reps[idx])
def test_input_layer_03():
    """Moving the layer to CUDA must move every node-rep tensor with it."""
    if torch.cuda.device_count() == 0:
        pytest.skip('No CUDA devices on this host')
    data = _some_data()
    layer = InputLayer(data, 32)
    device = torch.device('cuda:0')
    layer = layer.to(device)
    # Registered parameters are moved in place by Module.to(), so the
    # entries of layer.node_reps must now live on the CUDA device.
    assert layer.node_reps[0].device == device
    assert layer.node_reps[1].device == device