- from decagon_pytorch.layer import InputLayer, \
-     OneHotInputLayer, \
-     DecagonLayer
- from decagon_pytorch.data import Data
- import torch
- import pytest
- from decagon_pytorch.convolve import SparseDropoutGraphConvActivation, \
-     SparseMultiDGCA, \
-     DropoutGraphConvActivation
-
-
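- # Helper: a Data object with Gene/Drug node types and five relation types, with no adjacency matrices (None).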
- def _some_data():
-     d = Data()
-     d.add_node_type('Gene', 1000)
-     d.add_node_type('Drug', 100)
-     d.add_relation_type('Target', 1, 0, None)
-     d.add_relation_type('Interaction', 0, 0, None)
-     d.add_relation_type('Side Effect: Nausea', 1, 1, None)
-     d.add_relation_type('Side Effect: Infertility', 1, 1, None)
-     d.add_relation_type('Side Effect: Death', 1, 1, None)
-     return d
-
-
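- # Helper: the same node and relation types as _some_data(), but with random binary adjacency matrices attached.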
- def _some_data_with_interactions():
-     d = Data()
-     d.add_node_type('Gene', 1000)
-     d.add_node_type('Drug', 100)
-     d.add_relation_type('Target', 1, 0,
-         torch.rand((100, 1000), dtype=torch.float32).round())
-     d.add_relation_type('Interaction', 0, 0,
-         torch.rand((1000, 1000), dtype=torch.float32).round())
-     d.add_relation_type('Side Effect: Nausea', 1, 1,
-         torch.rand((100, 100), dtype=torch.float32).round())
-     d.add_relation_type('Side Effect: Infertility', 1, 1,
-         torch.rand((100, 100), dtype=torch.float32).round())
-     d.add_relation_type('Side Effect: Death', 1, 1,
-         torch.rand((100, 100), dtype=torch.float32).round())
-     return d
-
-
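- # InputLayer should expose one representation matrix per node type, with the requested output dimensionality.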
- def test_input_layer_01():
-     d = _some_data()
-     for output_dim in [32, 64, 128]:
-         layer = InputLayer(d, output_dim)
-         assert layer.output_dim[0] == output_dim
-         assert len(layer.node_reps) == 2
-         assert layer.node_reps[0].shape == (1000, output_dim)
-         assert layer.node_reps[1].shape == (100, output_dim)
-         assert layer.data == d
-
-
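- # Calling InputLayer should return its node representation tensors unchanged.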
- def test_input_layer_02():
-     d = _some_data()
-     layer = InputLayer(d, 32)
-     res = layer()
-     assert isinstance(res[0], torch.Tensor)
-     assert isinstance(res[1], torch.Tensor)
-     assert res[0].shape == (1000, 32)
-     assert res[1].shape == (100, 32)
-     assert torch.all(res[0] == layer.node_reps[0])
-     assert torch.all(res[1] == layer.node_reps[1])
-
-
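- # Moving InputLayer to a CUDA device should move its node representations as well.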
- def test_input_layer_03():
-     if torch.cuda.device_count() == 0:
-         pytest.skip('No CUDA devices on this host')
-     d = _some_data()
-     layer = InputLayer(d, 32)
-     device = torch.device('cuda:0')
-     layer = layer.to(device)
-     print(list(layer.parameters()))
-     # assert layer.device.type == 'cuda:0'
-     assert layer.node_reps[0].device == device
-     assert layer.node_reps[1].device == device
-
-
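- # OneHotInputLayer should produce sparse one-hot (identity) representations sized to each node type.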
- def test_one_hot_input_layer_01():
-     d = _some_data()
-     layer = OneHotInputLayer(d)
-     assert layer.output_dim == [1000, 100]
-     assert len(layer.node_reps) == 2
-     assert layer.node_reps[0].shape == (1000, 1000)
-     assert layer.node_reps[1].shape == (100, 100)
-     assert layer.data == d
-     assert layer.is_sparse
-
-
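- # Calling OneHotInputLayer should return its sparse node representations unchanged.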
- def test_one_hot_input_layer_02():
-     d = _some_data()
-     layer = OneHotInputLayer(d)
-     res = layer()
-     assert isinstance(res[0], torch.Tensor)
-     assert isinstance(res[1], torch.Tensor)
-     assert res[0].shape == (1000, 1000)
-     assert res[1].shape == (100, 100)
-     assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
-     assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
-
-
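- # Moving OneHotInputLayer to a CUDA device should move its node representations as well.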
- def test_one_hot_input_layer_03():
-     if torch.cuda.device_count() == 0:
-         pytest.skip('No CUDA devices on this host')
-     d = _some_data()
-     layer = OneHotInputLayer(d)
-     device = torch.device('cuda:0')
-     layer = layer.to(device)
-     print(list(layer.parameters()))
-     # assert layer.device.type == 'cuda:0'
-     assert layer.node_reps[0].device == device
-     assert layer.node_reps[1].device == device
-
-
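- # DecagonLayer should construct on top of a dense InputLayer without raising.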
- def test_decagon_layer_01():
-     d = _some_data_with_interactions()
-     in_layer = InputLayer(d)
-     d_layer = DecagonLayer(d, in_layer, output_dim=32)
-
-
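- # DecagonLayer built on a OneHotInputLayer should be callable.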
- def test_decagon_layer_02():
-     d = _some_data_with_interactions()
-     in_layer = OneHotInputLayer(d)
-     d_layer = DecagonLayer(d, in_layer, output_dim=32)
-     _ = d_layer() # dummy call
-
-
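- # DecagonLayer should record its configuration and build DropoutGraphConvActivation units (output_dim=32) for both node types.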
- def test_decagon_layer_03():
-     d = _some_data_with_interactions()
-     in_layer = OneHotInputLayer(d)
-     d_layer = DecagonLayer(d, in_layer, output_dim=32)
-     assert d_layer.data == d
-     assert d_layer.previous_layer == in_layer
-     assert d_layer.input_dim == [1000, 100]
-     assert not d_layer.is_sparse
-     assert d_layer.keep_prob == 1.
-     assert d_layer.rel_activation(0.5) == 0.5
-     x = torch.tensor([-1, 0, 0.5, 1])
-     assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
-     assert len(d_layer.next_layer_repr) == 2
-     assert len(d_layer.next_layer_repr[0]) == 2
-     assert len(d_layer.next_layer_repr[1]) == 4
-     assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation),
-         d_layer.next_layer_repr[0]))
-     assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation),
-         d_layer.next_layer_repr[1]))
-     assert all(map(lambda a: a[0].output_dim == 32,
-         d_layer.next_layer_repr[0]))
-     assert all(map(lambda a: a[0].output_dim == 32,
-         d_layer.next_layer_repr[1]))
-
-
- def test_decagon_layer_04():
-     # Check that DecagonLayer is equivalent to SparseMultiDGCA, as it should be.
-
-     d = Data()
-     d.add_node_type('Dummy', 100)
-     d.add_relation_type('Dummy Relation', 0, 0,
-         torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
-
-     in_layer = OneHotInputLayer(d)
-
-     multi_dgca = SparseMultiDGCA([10], 32,
-         [r.adjacency_matrix for r in d.relation_types[0, 0]],
-         keep_prob=1., activation=lambda x: x)
-
-     d_layer = DecagonLayer(d, in_layer, output_dim=32,
-         keep_prob=1., rel_activation=lambda x: x,
-         layer_activation=lambda x: x)
-
-     assert isinstance(d_layer.next_layer_repr[0][0][0],
-         DropoutGraphConvActivation)
-
-     weight = d_layer.next_layer_repr[0][0][0].graph_conv.weight
-     assert isinstance(weight, torch.Tensor)
-
-     assert len(multi_dgca.sparse_dgca) == 1
-     assert isinstance(multi_dgca.sparse_dgca[0], SparseDropoutGraphConvActivation)
-
-     multi_dgca.sparse_dgca[0].sparse_graph_conv.weight = weight
-
-     out_d_layer = d_layer()
-     out_multi_dgca = multi_dgca(in_layer())
-
-     assert isinstance(out_d_layer, list)
-     assert len(out_d_layer) == 1
-
-     assert torch.all(out_d_layer[0] == out_multi_dgca)