|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- from decagon_pytorch.layer import InputLayer
- from decagon_pytorch.data import Data
- import torch
- import pytest
-
-
def _some_data():
    """Build the small ``Data`` fixture shared by the tests in this module.

    Registers two node types ('Gene' with 1000 entries, 'Drug' with 100)
    and five relation types, each with ``None`` as the final argument.
    """
    # (name, count) pairs, registered in this order.
    node_types = (
        ('Gene', 1000),
        ('Drug', 100),
    )
    # (name, first index, second index) triples; the two integers are
    # presumably node-type indices (0 = Gene, 1 = Drug) -- TODO confirm
    # against Data.add_relation_type.
    relation_types = (
        ('Target', 1, 0),
        ('Interaction', 0, 0),
        ('Side Effect: Nausea', 1, 1),
        ('Side Effect: Infertility', 1, 1),
        ('Side Effect: Death', 1, 1),
    )

    data = Data()
    for name, count in node_types:
        data.add_node_type(name, count)
    for name, first_idx, second_idx in relation_types:
        data.add_relation_type(name, first_idx, second_idx, None)
    return data
-
-
def test_input_layer_01():
    """Check the attributes of ``InputLayer`` for several dimensionalities.

    For each of 32, 64 and 128 the layer must report that dimensionality,
    hold two node representations shaped (1000, dim) and (100, dim), and
    keep a ``data`` attribute equal to the fixture it was built from.
    """
    fixture = _some_data()
    expected_rows = (1000, 100)

    for dim in (32, 64, 128):
        layer = InputLayer(fixture, dim)
        assert layer.dimensionality == dim
        assert len(layer.node_reps) == 2
        for idx, n_rows in enumerate(expected_rows):
            assert layer.node_reps[idx].shape == (n_rows, dim)
        assert layer.data == fixture
-
-
def test_input_layer_02():
    """Check the result of calling an ``InputLayer`` built with 32 dims.

    The first two entries of the result must be ``torch.Tensor`` objects
    shaped (1000, 32) and (100, 32), element-wise equal to the matching
    entries of ``layer.node_reps``.
    """
    dim = 32
    layer = InputLayer(_some_data(), dim)
    outputs = layer()

    for idx, n_rows in enumerate((1000, 100)):
        out = outputs[idx]
        assert isinstance(out, torch.Tensor)
        assert out.shape == (n_rows, dim)
        assert torch.all(out == layer.node_reps[idx])
-
-
def test_input_layer_03():
    """Check that ``InputLayer.to`` moves the node representations to CUDA.

    Skipped when the host reports no CUDA devices. Otherwise the layer is
    moved to ``cuda:0`` and both entries of ``layer.node_reps`` must
    report that device.
    """
    if torch.cuda.device_count() == 0:
        pytest.skip('No CUDA devices on this host')

    d = _some_data()
    layer = InputLayer(d, 32)
    device = torch.device('cuda:0')
    # Use the value returned by ``to`` rather than relying on in-place
    # mutation of the original object.
    layer = layer.to(device)

    assert layer.node_reps[0].device == device
    assert layer.node_reps[1].device == device
|