|
|
@@ -1,4 +1,5 @@ |
|
|
|
from decagon_pytorch.layer import InputLayer
|
|
|
|
from decagon_pytorch.layer import InputLayer, \
|
|
|
|
DecagonLayer
|
|
|
|
from decagon_pytorch.data import Data
|
|
|
|
import torch
|
|
|
|
import pytest
|
|
|
@@ -16,11 +17,28 @@ def _some_data(): |
|
|
|
return d
|
|
|
|
|
|
|
|
|
|
|
|
def _some_data_with_interactions():
|
|
|
|
d = Data()
|
|
|
|
d.add_node_type('Gene', 1000)
|
|
|
|
d.add_node_type('Drug', 100)
|
|
|
|
d.add_relation_type('Target', 1, 0,
|
|
|
|
torch.rand((100, 1000), dtype=torch.float32).round())
|
|
|
|
d.add_relation_type('Interaction', 0, 0,
|
|
|
|
torch.rand((1000, 1000), dtype=torch.float32).round())
|
|
|
|
d.add_relation_type('Side Effect: Nausea', 1, 1,
|
|
|
|
torch.rand((100, 100), dtype=torch.float32).round())
|
|
|
|
d.add_relation_type('Side Effect: Infertility', 1, 1,
|
|
|
|
torch.rand((100, 100), dtype=torch.float32).round())
|
|
|
|
d.add_relation_type('Side Effect: Death', 1, 1,
|
|
|
|
torch.rand((100, 100), dtype=torch.float32).round())
|
|
|
|
return d
|
|
|
|
|
|
|
|
|
|
|
|
def test_input_layer_01():
|
|
|
|
d = _some_data()
|
|
|
|
for dimensionality in [32, 64, 128]:
|
|
|
|
layer = InputLayer(d, dimensionality)
|
|
|
|
assert layer.dimensionality == dimensionality
|
|
|
|
assert layer.dimensionality[0] == dimensionality
|
|
|
|
assert len(layer.node_reps) == 2
|
|
|
|
assert layer.node_reps[0].shape == (1000, dimensionality)
|
|
|
|
assert layer.node_reps[1].shape == (100, dimensionality)
|
|
|
@@ -50,3 +68,10 @@ def test_input_layer_03(): |
|
|
|
# assert layer.device.type == 'cuda:0'
|
|
|
|
assert layer.node_reps[0].device == device
|
|
|
|
assert layer.node_reps[1].device == device
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip()
|
|
|
|
def test_decagon_layer_01():
|
|
|
|
d = _some_data_with_interactions()
|
|
|
|
in_layer = InputLayer(d)
|
|
|
|
d_layer = DecagonLayer(in_layer, output_dim=32)
|