| @@ -0,0 +1,60 @@ | |||
| # | |||
| # This module implements a single layer of the Decagon | |||
| # model. This is going to be already quite complex, as | |||
| # we will be using all the graph convolutional building | |||
| # blocks. | |||
| # | |||
| # h_{i}^(k+1) = ϕ(∑_r ∑_{j∈N{r}^{i}} c_{r}^{ij} * \ | |||
| # W_{r}^(k) h_{j}^{k} + c_{r}^{i} h_{i}^(k)) | |||
| # | |||
| # N{r}^{i} - set of neighbors of node i under relation r | |||
| # W_{r}^(k) - relation-type specific weight matrix | |||
| # h_{i}^(k) - hidden state of node i in layer k | |||
| # h_{i}^(k)∈R^{d(k)} where d(k) is the dimensionality | |||
| # of the representation in k-th layer | |||
| # ϕ - activation function | |||
| # c_{r}^{ij} - normalization constants | |||
| # c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) | |||
| # c_{r}^{i} - normalization constants | |||
| # c_{r}^{i} = 1/|N_{r}^{i}| | |||
| # | |||
| import torch | |||
| class InputLayer(torch.nn.Module): | |||
| def __init__(self, data, dimensionality=32, **kwargs): | |||
| super().__init__(**kwargs) | |||
| self.data = data | |||
| self.dimensionality = dimensionality | |||
| self.node_reps = None | |||
| self.build() | |||
| def build(self): | |||
| self.node_reps = [] | |||
| for i, nt in enumerate(self.data.node_types): | |||
| reps = torch.rand(nt.count, self.dimensionality) | |||
| reps = torch.nn.Parameter(reps) | |||
| self.register_parameter('node_reps[%d]' % i, reps) | |||
| self.node_reps.append(reps) | |||
| def forward(self): | |||
| return self.node_reps | |||
| def __repr__(self): | |||
| s = '' | |||
| s += 'GNN input layer with dimensionality: %d\n' % self.dimensionality | |||
| s += ' # of node types: %d\n' % len(self.data.node_types) | |||
| for nt in self.data.node_types: | |||
| s += ' - %s (%d)\n' % (nt.name, nt.count) | |||
| return s.strip() | |||
| class DecagonLayer(torch.nn.Module): | |||
| def __init__(self, data, **kwargs): | |||
| super().__init__(**kwargs) | |||
| self.data = data | |||
| def __call__(self, previous_layer): | |||
| pass | |||
| @@ -1,5 +1,4 @@ | |||
| from decagon_pytorch.data import Data | |||
| from decagon_pytorch.decode import DEDICOMDecoder | |||
| def test_data(): | |||
| @@ -0,0 +1,52 @@ | |||
| from decagon_pytorch.layer import InputLayer | |||
| from decagon_pytorch.data import Data | |||
| import torch | |||
| import pytest | |||
| def _some_data(): | |||
| d = Data() | |||
| d.add_node_type('Gene', 1000) | |||
| d.add_node_type('Drug', 100) | |||
| d.add_relation_type('Target', 1, 0, None) | |||
| d.add_relation_type('Interaction', 0, 0, None) | |||
| d.add_relation_type('Side Effect: Nausea', 1, 1, None) | |||
| d.add_relation_type('Side Effect: Infertility', 1, 1, None) | |||
| d.add_relation_type('Side Effect: Death', 1, 1, None) | |||
| return d | |||
| def test_input_layer_01(): | |||
| d = _some_data() | |||
| for dimensionality in [32, 64, 128]: | |||
| layer = InputLayer(d, dimensionality) | |||
| assert layer.dimensionality == dimensionality | |||
| assert len(layer.node_reps) == 2 | |||
| assert layer.node_reps[0].shape == (1000, dimensionality) | |||
| assert layer.node_reps[1].shape == (100, dimensionality) | |||
| assert layer.data == d | |||
| def test_input_layer_02(): | |||
| d = _some_data() | |||
| layer = InputLayer(d, 32) | |||
| res = layer() | |||
| assert isinstance(res[0], torch.Tensor) | |||
| assert isinstance(res[1], torch.Tensor) | |||
| assert res[0].shape == (1000, 32) | |||
| assert res[1].shape == (100, 32) | |||
| assert torch.all(res[0] == layer.node_reps[0]) | |||
| assert torch.all(res[1] == layer.node_reps[1]) | |||
| def test_input_layer_03(): | |||
| if torch.cuda.device_count() == 0: | |||
| pytest.skip('No CUDA devices on this host') | |||
| d = _some_data() | |||
| layer = InputLayer(d, 32) | |||
| device = torch.device('cuda:0') | |||
| layer = layer.to(device) | |||
| print(list(layer.parameters())) | |||
| # assert layer.device.type == 'cuda:0' | |||
| assert layer.node_reps[0].device == device | |||
| assert layer.node_reps[1].device == device | |||