"""Tests for ``decagon_pytorch.layer``: input layers and ``DecagonLayer``."""

import pytest
import torch

from decagon_pytorch.convolve import (
    DropoutGraphConvActivation,
    SparseDropoutGraphConvActivation,
    SparseMultiDGCA,
)
from decagon_pytorch.data import Data
from decagon_pytorch.layer import DecagonLayer, InputLayer, OneHotInputLayer


# Shared skip marker for the device-placement tests.
_requires_cuda = pytest.mark.skipif(
    torch.cuda.device_count() == 0, reason='No CUDA devices on this host')


def _random_adjacency(rows, cols, generator=None):
    """Return a random float32 0/1 matrix of shape ``(rows, cols)``.

    Entries are ``torch.rand`` samples in [0, 1) rounded to the nearest
    integer, so every entry is 0. or 1.  Pass a seeded ``torch.Generator``
    to make the result reproducible.
    """
    return torch.rand((rows, cols), generator=generator,
                      dtype=torch.float32).round()


def _some_data():
    """Build a ``Data`` with two node types and relation types without adjacency.

    Node types are added in order: 'Gene' (1000 nodes), then 'Drug' (100 nodes).
    """
    d = Data()
    d.add_node_type('Gene', 1000)
    d.add_node_type('Drug', 100)
    d.add_relation_type('Target', 1, 0, None)
    d.add_relation_type('Interaction', 0, 0, None)
    d.add_relation_type('Side Effect: Nausea', 1, 1, None)
    d.add_relation_type('Side Effect: Infertility', 1, 1, None)
    d.add_relation_type('Side Effect: Death', 1, 1, None)
    return d


def _some_data_with_interactions(seed=0):
    """Same schema as ``_some_data`` but with random 0/1 adjacency matrices.

    A single generator seeded with ``seed`` is shared by all matrices, so the
    fixture is reproducible while each relation still gets distinct values.
    """
    gen = torch.Generator().manual_seed(seed)
    d = Data()
    d.add_node_type('Gene', 1000)
    d.add_node_type('Drug', 100)
    d.add_relation_type('Target', 1, 0,
                        _random_adjacency(100, 1000, gen))
    d.add_relation_type('Interaction', 0, 0,
                        _random_adjacency(1000, 1000, gen))
    d.add_relation_type('Side Effect: Nausea', 1, 1,
                        _random_adjacency(100, 100, gen))
    d.add_relation_type('Side Effect: Infertility', 1, 1,
                        _random_adjacency(100, 100, gen))
    d.add_relation_type('Side Effect: Death', 1, 1,
                        _random_adjacency(100, 100, gen))
    return d


@pytest.mark.parametrize('output_dim', [32, 64, 128])
def test_input_layer_01(output_dim):
    """InputLayer allocates one (num_nodes, output_dim) tensor per node type."""
    d = _some_data()
    layer = InputLayer(d, output_dim)
    assert layer.output_dim[0] == output_dim
    assert len(layer.node_reps) == 2
    assert layer.node_reps[0].shape == (1000, output_dim)
    assert layer.node_reps[1].shape == (100, output_dim)
    assert layer.data == d


def test_input_layer_02():
    """Calling InputLayer returns its node representations unchanged."""
    d = _some_data()
    layer = InputLayer(d, 32)
    res = layer()
    assert isinstance(res[0], torch.Tensor)
    assert isinstance(res[1], torch.Tensor)
    assert res[0].shape == (1000, 32)
    assert res[1].shape == (100, 32)
    assert torch.all(res[0] == layer.node_reps[0])
    assert torch.all(res[1] == layer.node_reps[1])


@_requires_cuda
def test_input_layer_03():
    """Moving InputLayer to a CUDA device moves its node representations."""
    d = _some_data()
    layer = InputLayer(d, 32)
    device = torch.device('cuda:0')
    layer = layer.to(device)
    assert layer.node_reps[0].device == device
    assert layer.node_reps[1].device == device


def test_one_hot_input_layer_01():
    """OneHotInputLayer builds sparse (num_nodes, num_nodes) representations."""
    d = _some_data()
    layer = OneHotInputLayer(d)
    assert layer.output_dim == [1000, 100]
    assert len(layer.node_reps) == 2
    assert layer.node_reps[0].shape == (1000, 1000)
    assert layer.node_reps[1].shape == (100, 100)
    assert layer.data == d
    assert layer.is_sparse


def test_one_hot_input_layer_02():
    """Calling OneHotInputLayer returns its (sparse) node representations."""
    d = _some_data()
    layer = OneHotInputLayer(d)
    res = layer()
    assert isinstance(res[0], torch.Tensor)
    assert isinstance(res[1], torch.Tensor)
    assert res[0].shape == (1000, 1000)
    assert res[1].shape == (100, 100)
    assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
    assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())


@_requires_cuda
def test_one_hot_input_layer_03():
    """Moving OneHotInputLayer to a CUDA device moves its representations."""
    d = _some_data()
    layer = OneHotInputLayer(d)
    device = torch.device('cuda:0')
    layer = layer.to(device)
    assert layer.node_reps[0].device == device
    assert layer.node_reps[1].device == device


def test_decagon_layer_01():
    """DecagonLayer can be built on top of a dense InputLayer."""
    d = _some_data_with_interactions()
    in_layer = InputLayer(d)
    d_layer = DecagonLayer(d, in_layer, output_dim=32)
    assert d_layer.data == d
    assert d_layer.previous_layer == in_layer


def test_decagon_layer_02():
    """A forward pass over one-hot input yields one output per node type."""
    d = _some_data_with_interactions()
    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(d, in_layer, output_dim=32)
    out = d_layer()
    assert isinstance(out, list)
    assert len(out) == 2


def test_decagon_layer_03():
    """DecagonLayer exposes its configuration and per-relation convolutions."""
    d = _some_data_with_interactions()
    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(d, in_layer, output_dim=32)
    assert d_layer.data == d
    assert d_layer.previous_layer == in_layer
    assert d_layer.input_dim == [1000, 100]
    assert not d_layer.is_sparse
    assert d_layer.keep_prob == 1.
    assert d_layer.rel_activation(0.5) == 0.5
    x = torch.tensor([-1, 0, 0.5, 1])
    assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()

    assert len(d_layer.next_layer_repr) == 2
    assert len(d_layer.next_layer_repr[0]) == 2
    assert len(d_layer.next_layer_repr[1]) == 4
    # Check each entry individually so a failure points at the offending one.
    for node_type_repr in d_layer.next_layer_repr:
        for entry in node_type_repr:
            conv = entry[0]
            assert isinstance(conv, DropoutGraphConvActivation)
            assert conv.output_dim == 32


def test_decagon_layer_04():
    """A single-relation DecagonLayer must be equivalent to SparseMultiDGCA."""
    d = Data()
    d.add_node_type('Dummy', 100)
    gen = torch.Generator().manual_seed(0)
    d.add_relation_type('Dummy Relation', 0, 0,
                        _random_adjacency(100, 100, gen).to_sparse())

    in_layer = OneHotInputLayer(d)

    # One-hot features for 100 nodes are 100-dimensional.
    multi_dgca = SparseMultiDGCA(
        [100], 32,
        [r.adjacency_matrix for r in d.relation_types[0, 0]],
        keep_prob=1., activation=lambda x: x)

    d_layer = DecagonLayer(d, in_layer, output_dim=32,
                           keep_prob=1., rel_activation=lambda x: x,
                           layer_activation=lambda x: x)

    assert isinstance(d_layer.next_layer_repr[0][0][0],
                      DropoutGraphConvActivation)

    weight = d_layer.next_layer_repr[0][0][0].graph_conv.weight
    assert isinstance(weight, torch.Tensor)

    assert len(multi_dgca.sparse_dgca) == 1
    assert isinstance(multi_dgca.sparse_dgca[0],
                      SparseDropoutGraphConvActivation)

    # Share the exact weight so both paths compute the same product.
    multi_dgca.sparse_dgca[0].sparse_graph_conv.weight = weight

    out_d_layer = d_layer()
    out_multi_dgca = multi_dgca(in_layer())

    assert isinstance(out_d_layer, list)
    assert len(out_d_layer) == 1

    assert torch.all(out_d_layer[0] == out_multi_dgca)