From 8767db601dcb19804143d98055c33580ac58c709 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 12 Jun 2020 16:14:03 +0200 Subject: [PATCH] Remove node_type_row/node_type_column as redundant in add_relation_type(). --- src/icosagon/data.py | 38 ++++++------- tests/icosagon/test_convlayer.py | 16 +++--- tests/icosagon/test_data.py | 97 ++++++++++++++++++++++++++++---- tests/icosagon/test_declayer.py | 2 +- tests/icosagon/test_decode.py | 1 - tests/icosagon/test_input.py | 10 ++-- 6 files changed, 115 insertions(+), 49 deletions(-) diff --git a/src/icosagon/data.py b/src/icosagon/data.py index 53bd769..0181b7d 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -80,22 +80,12 @@ class RelationFamily(RelationFamilyBase): self.relation_types = [] def add_relation_type(self, - name: str, node_type_row: int, node_type_column: int, - adjacency_matrix: torch.Tensor, + name: str, adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None) -> None: name = str(name) - node_type_row = int(node_type_row) - node_type_column = int(node_type_column) - - if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column): - raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') - - if node_type_row < 0 or node_type_row >= len(self.data.node_types): - raise ValueError('node_type_row outside of the valid range of node types') - - if node_type_column < 0 or node_type_column >= len(self.data.node_types): - raise ValueError('node_type_column outside of the valid range of node types') + node_type_row = self.node_type_row + node_type_column = self.node_type_column if adjacency_matrix is None and adjacency_matrix_backward is None: raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') @@ -104,13 +94,6 @@ class RelationFamily(RelationFamilyBase): not isinstance(adjacency_matrix, torch.Tensor): raise ValueError('adjacency_matrix must be a torch.Tensor') - # if isinstance(adjacency_matrix_backward, str) and \ - # adjacency_matrix_backward == 'symmetric': - # if self.is_symmetric: - # adjacency_matrix_backward = None - # else: - # adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) - if adjacency_matrix_backward is not None \ and not isinstance(adjacency_matrix_backward, torch.Tensor): raise ValueError('adjacency_matrix_backward must be a torch.Tensor') @@ -154,7 +137,7 @@ class RelationFamily(RelationFamilyBase): s += '\n - %s%s' % (r.name, ' (two-way)' \ if (r.adjacency_matrix is not None \ and r.adjacency_matrix_backward is not None) \ - or self.is_symmetric \ + or self.node_type_row == self.node_type_column \ else '%s <- %s' % (self.node_name(self.node_type_row), self.node_name(self.node_type_column))) @@ -167,7 +150,7 @@ class RelationFamily(RelationFamilyBase): s += '\n - %s%s' % (r.name, ' (two-way)' \ if (r.adjacency_matrix is not None \ and r.adjacency_matrix_backward is not None) \ - or self.is_symmetric \ + or self.node_type_row == self.node_type_column \ else '%s <- %s' % (self.node_name(self.node_type_row), self.node_name(self.node_type_column))) @@ -195,6 +178,17 @@ class Data(object): node_type_column: int, is_symmetric: bool, decoder_class: Type = DEDICOMDecoder): + name = str(name) + node_type_row = int(node_type_row) + node_type_column = int(node_type_column) + is_symmetric = bool(is_symmetric) + + if node_type_row < 0 or node_type_row >= len(self.node_types): + raise ValueError('node_type_row outside of the valid range of node types') + + if node_type_column < 0 or node_type_column >= len(self.node_types): + raise ValueError('node_type_column outside of the valid range of node types') + fam = RelationFamily(self, name, node_type_row, node_type_column, is_symmetric, decoder_class) self.relation_families.append(fam) diff --git a/tests/icosagon/test_convlayer.py b/tests/icosagon/test_convlayer.py index 11e8ffc..145ad95 100644 --- a/tests/icosagon/test_convlayer.py +++ b/tests/icosagon/test_convlayer.py @@ -26,19 +26,19 @@ def _some_data_with_interactions(): d.add_node_type('Drug', 100) fam = d.add_relation_family('Drug-Gene', 1, 0, True) - fam.add_relation_type('Target', 1, 0, + fam.add_relation_type('Target', torch.rand((100, 1000), dtype=torch.float32).round()) fam = d.add_relation_family('Gene-Gene', 0, 0, True) - fam.add_relation_type('Interaction', 0, 0, + fam.add_relation_type('Interaction', _symmetric_random(1000, 1000)) fam = d.add_relation_family('Drug-Drug', 1, 1, True) - fam.add_relation_type('Side Effect: Nausea', 1, 1, + fam.add_relation_type('Side Effect: Nausea', _symmetric_random(100, 100)) - fam.add_relation_type('Side Effect: Infertility', 1, 1, + fam.add_relation_type('Side Effect: Infertility', _symmetric_random(100, 100)) - fam.add_relation_type('Side Effect: Death', 1, 1, + fam.add_relation_type('Side Effect: Death', _symmetric_random(100, 100)) return d @@ -97,7 +97,7 @@ def test_decagon_layer_04(): d = Data() d.add_node_type('Dummy', 100) fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) - fam.add_relation_type('Dummy Relation', 0, 0, + fam.add_relation_type('Dummy Relation', _symmetric_random(100, 100).to_sparse()) in_layer = OneHotInputLayer(d) @@ -139,9 +139,9 @@ def test_decagon_layer_05(): d = Data() d.add_node_type('Dummy', 100) fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) - fam.add_relation_type('Dummy Relation 1', 0, 0, + fam.add_relation_type('Dummy Relation 1', _symmetric_random(100, 100).to_sparse()) - fam.add_relation_type('Dummy Relation 2', 0, 0, + fam.add_relation_type('Dummy Relation 2', _symmetric_random(100, 100).to_sparse()) in_layer = OneHotInputLayer(d) diff --git a/tests/icosagon/test_data.py b/tests/icosagon/test_data.py index d67c1c1..08ded89 100644 --- a/tests/icosagon/test_data.py +++ b/tests/icosagon/test_data.py @@ -4,11 +4,80 @@ # -from icosagon import Data +from icosagon.data import Data, \ + _equal, \ + RelationFamily +from icosagon.decode import DEDICOMDecoder import torch import pytest +def test_equal_01(): + x = torch.rand((10, 10)) + y = torch.rand((10, 10)).round().to_sparse() + + assert torch.all(_equal(x, x)) + assert torch.all(_equal(y, y)) + with pytest.raises(ValueError): + _equal(x, y) + + z = torch.rand((10, 10)).round().to_sparse() + assert not torch.all(_equal(y, z)) + + +def test_relation_family_01(): + d = Data() + d.add_node_type('Whatever', 10) + + fam = RelationFamily(d, 'Dummy-Dummy', 0, 0, True, DEDICOMDecoder) + + with pytest.raises(ValueError): + fam.add_relation_type('Dummy-Dummy', None, None) + + with pytest.raises(ValueError): + fam.add_relation_type('Dummy-Dummy', 'bad-value', None) + + with pytest.raises(ValueError): + fam.add_relation_type('Dummy-Dummy', None, 'bad-value') + + with pytest.raises(ValueError): + fam.add_relation_type('Dummy-Dummy', torch.rand((5, 5)), None) + + with pytest.raises(ValueError): + fam.add_relation_type('Dummy-Dummy', None, torch.rand((5, 5))) + + with pytest.raises(ValueError): + fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), torch.rand((10, 10))) + + with pytest.raises(ValueError): + fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), None) + + +def test_relation_family_02(): + d = Data() + d.add_node_type('A', 10) + d.add_node_type('B', 5) + + fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder) + + with pytest.raises(ValueError): + fam.add_relation_type('A-B', torch.rand((10, 5)).round(), + torch.rand((5, 10)).round()) + + +def test_relation_family_03(): + d = Data() + d.add_node_type('A', 10) + d.add_node_type('B', 5) + + fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder) + + fam.add_relation_type('A-B', torch.rand((10, 5)).round()) + + assert torch.all(fam.relation_types[0].adjacency_matrix.transpose(0, 1) == \ + fam.relation_types[0].adjacency_matrix_backward) + + def test_data_01(): d = Data() d.add_node_type('Gene', 1000) @@ -17,14 +86,18 @@ def test_data_01(): dummy_1 = torch.zeros((1000, 100)) dummy_2 = torch.zeros((100, 100)) dummy_3 = torch.zeros((1000, 1000)) + fam = d.add_relation_family('Drug-Gene', 1, 0, True) - fam.add_relation_type('Target', 1, 0, dummy_0) + fam.add_relation_type('Target', dummy_0) + fam = d.add_relation_family('Gene-Gene', 0, 0, True) - fam.add_relation_type('Interaction', 0, 0, dummy_3) + fam.add_relation_type('Interaction', dummy_3) + fam = d.add_relation_family('Drug-Drug', 1, 1, True) - fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_2) - fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_2) - fam.add_relation_type('Side Effect: Death', 1, 1, dummy_2) + fam.add_relation_type('Side Effect: Nausea', dummy_2) + fam.add_relation_type('Side Effect: Infertility', dummy_2) + fam.add_relation_type('Side Effect: Death', dummy_2) + print(d) @@ -40,19 +113,19 @@ def test_data_02(): fam = d.add_relation_family('Drug-Gene', 1, 0, True) with pytest.raises(ValueError): - fam.add_relation_type('Target', 1, 0, dummy_1) + fam.add_relation_type('Target', dummy_1) fam = d.add_relation_family('Gene-Gene', 0, 0, True) with pytest.raises(ValueError): - fam.add_relation_type('Interaction', 0, 0, dummy_2) + fam.add_relation_type('Interaction', dummy_2) fam = d.add_relation_family('Drug-Drug', 1, 1, True) with pytest.raises(ValueError): - fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_3) + fam.add_relation_type('Side Effect: Nausea', dummy_3) with pytest.raises(ValueError): - fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_3) + fam.add_relation_type('Side Effect: Infertility', dummy_3) with pytest.raises(ValueError): - fam.add_relation_type('Side Effect: Death', 1, 1, dummy_3) + fam.add_relation_type('Side Effect: Death', dummy_3) print(d) @@ -62,5 +135,5 @@ def test_data_03(): d.add_node_type('Drug', 100) fam = d.add_relation_family('Drug-Gene', 1, 0, True) with pytest.raises(ValueError): - fam.add_relation_type('Target', 1, 0, None) + fam.add_relation_type('Target', None) print(d) diff --git a/tests/icosagon/test_declayer.py b/tests/icosagon/test_declayer.py index bab9a25..14b531b 100644 --- a/tests/icosagon/test_declayer.py +++ b/tests/icosagon/test_declayer.py @@ -22,7 +22,7 @@ def test_decode_layer_01(): d.add_node_type('Dummy', 100) fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) - fam.add_relation_type('Dummy Relation 1', 0, 0, + fam.add_relation_type('Dummy Relation 1', torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) diff --git a/tests/icosagon/test_decode.py b/tests/icosagon/test_decode.py index ee0eaf2..f52a3c7 100644 --- a/tests/icosagon/test_decode.py +++ b/tests/icosagon/test_decode.py @@ -151,7 +151,6 @@ def test_is_inner_product_symmetric_01(): res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] - assert isinstance(res_1, list) assert isinstance(res_2, list) diff --git a/tests/icosagon/test_input.py b/tests/icosagon/test_input.py index 1ac2e3d..4f50912 100644 --- a/tests/icosagon/test_input.py +++ b/tests/icosagon/test_input.py @@ -11,15 +11,15 @@ def _some_data(): d.add_node_type('Drug', 100) fam = d.add_relation_family('Drug-Gene', 1, 0, False) - fam.add_relation_type('Target', 1, 0, torch.rand(100, 1000)) + fam.add_relation_type('Target', torch.rand(100, 1000)) fam = d.add_relation_family('Gene-Gene', 0, 0, False) - fam.add_relation_type('Interaction', 0, 0, torch.rand(1000, 1000)) + fam.add_relation_type('Interaction', torch.rand(1000, 1000)) fam = d.add_relation_family('Drug-Drug', 1, 1, False) - fam.add_relation_type('Side Effect: Nausea', 1, 1, torch.rand(100, 100)) - fam.add_relation_type('Side Effect: Infertility', 1, 1, torch.rand(100, 100)) - fam.add_relation_type('Side Effect: Death', 1, 1, torch.rand(100, 100)) + fam.add_relation_type('Side Effect: Nausea', torch.rand(100, 100)) + fam.add_relation_type('Side Effect: Infertility', torch.rand(100, 100)) + fam.add_relation_type('Side Effect: Death', torch.rand(100, 100)) return d