IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Przeglądaj źródła

Remove node_type_row/node_type_column as redundant in add_relation_type().

master
Stanislaw Adaszewski 4 lat temu
rodzic
commit
8767db601d
6 zmienionych plików z 115 dodań i 49 usunięć
  1. +16
    -22
      src/icosagon/data.py
  2. +8
    -8
      tests/icosagon/test_convlayer.py
  3. +85
    -12
      tests/icosagon/test_data.py
  4. +1
    -1
      tests/icosagon/test_declayer.py
  5. +0
    -1
      tests/icosagon/test_decode.py
  6. +5
    -5
      tests/icosagon/test_input.py

+ 16
- 22
src/icosagon/data.py Wyświetl plik

@@ -80,22 +80,12 @@ class RelationFamily(RelationFamilyBase):
self.relation_types = []
def add_relation_type(self,
name: str, node_type_row: int, node_type_column: int,
adjacency_matrix: torch.Tensor,
name: str, adjacency_matrix: torch.Tensor,
adjacency_matrix_backward: torch.Tensor = None) -> None:
name = str(name)
node_type_row = int(node_type_row)
node_type_column = int(node_type_column)
if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column):
raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family')
if node_type_row < 0 or node_type_row >= len(self.data.node_types):
raise ValueError('node_type_row outside of the valid range of node types')
if node_type_column < 0 or node_type_column >= len(self.data.node_types):
raise ValueError('node_type_column outside of the valid range of node types')
node_type_row = self.node_type_row
node_type_column = self.node_type_column
if adjacency_matrix is None and adjacency_matrix_backward is None:
raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')
@@ -104,13 +94,6 @@ class RelationFamily(RelationFamilyBase):
not isinstance(adjacency_matrix, torch.Tensor):
raise ValueError('adjacency_matrix must be a torch.Tensor')
# if isinstance(adjacency_matrix_backward, str) and \
# adjacency_matrix_backward == 'symmetric':
# if self.is_symmetric:
# adjacency_matrix_backward = None
# else:
# adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
if adjacency_matrix_backward is not None \
and not isinstance(adjacency_matrix_backward, torch.Tensor):
raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
@@ -154,7 +137,7 @@ class RelationFamily(RelationFamilyBase):
s += '\n - %s%s' % (r.name, ' (two-way)' \
if (r.adjacency_matrix is not None \
and r.adjacency_matrix_backward is not None) \
or self.is_symmetric \
or self.node_type_row == self.node_type_column \
else '%s <- %s' % (self.node_name(self.node_type_row),
self.node_name(self.node_type_column)))
@@ -167,7 +150,7 @@ class RelationFamily(RelationFamilyBase):
s += '\n - %s%s' % (r.name, ' (two-way)' \
if (r.adjacency_matrix is not None \
and r.adjacency_matrix_backward is not None) \
or self.is_symmetric \
or self.node_type_row == self.node_type_column \
else '%s <- %s' % (self.node_name(self.node_type_row),
self.node_name(self.node_type_column)))
@@ -195,6 +178,17 @@ class Data(object):
node_type_column: int, is_symmetric: bool,
decoder_class: Type = DEDICOMDecoder):
name = str(name)
node_type_row = int(node_type_row)
node_type_column = int(node_type_column)
is_symmetric = bool(is_symmetric)
if node_type_row < 0 or node_type_row >= len(self.node_types):
raise ValueError('node_type_row outside of the valid range of node types')
if node_type_column < 0 or node_type_column >= len(self.node_types):
raise ValueError('node_type_column outside of the valid range of node types')
fam = RelationFamily(self, name, node_type_row, node_type_column,
is_symmetric, decoder_class)
self.relation_families.append(fam)


+ 8
- 8
tests/icosagon/test_convlayer.py Wyświetl plik

@@ -26,19 +26,19 @@ def _some_data_with_interactions():
d.add_node_type('Drug', 100)
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
fam.add_relation_type('Target', 1, 0,
fam.add_relation_type('Target',
torch.rand((100, 1000), dtype=torch.float32).round())
fam = d.add_relation_family('Gene-Gene', 0, 0, True)
fam.add_relation_type('Interaction', 0, 0,
fam.add_relation_type('Interaction',
_symmetric_random(1000, 1000))
fam = d.add_relation_family('Drug-Drug', 1, 1, True)
fam.add_relation_type('Side Effect: Nausea', 1, 1,
fam.add_relation_type('Side Effect: Nausea',
_symmetric_random(100, 100))
fam.add_relation_type('Side Effect: Infertility', 1, 1,
fam.add_relation_type('Side Effect: Infertility',
_symmetric_random(100, 100))
fam.add_relation_type('Side Effect: Death', 1, 1,
fam.add_relation_type('Side Effect: Death',
_symmetric_random(100, 100))
return d
@@ -97,7 +97,7 @@ def test_decagon_layer_04():
d = Data()
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
fam.add_relation_type('Dummy Relation', 0, 0,
fam.add_relation_type('Dummy Relation',
_symmetric_random(100, 100).to_sparse())
in_layer = OneHotInputLayer(d)
@@ -139,9 +139,9 @@ def test_decagon_layer_05():
d = Data()
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
fam.add_relation_type('Dummy Relation 1', 0, 0,
fam.add_relation_type('Dummy Relation 1',
_symmetric_random(100, 100).to_sparse())
fam.add_relation_type('Dummy Relation 2', 0, 0,
fam.add_relation_type('Dummy Relation 2',
_symmetric_random(100, 100).to_sparse())
in_layer = OneHotInputLayer(d)


+ 85
- 12
tests/icosagon/test_data.py Wyświetl plik

@@ -4,11 +4,80 @@
#
from icosagon import Data
from icosagon.data import Data, \
_equal, \
RelationFamily
from icosagon.decode import DEDICOMDecoder
import torch
import pytest
def test_equal_01():
x = torch.rand((10, 10))
y = torch.rand((10, 10)).round().to_sparse()
assert torch.all(_equal(x, x))
assert torch.all(_equal(y, y))
with pytest.raises(ValueError):
_equal(x, y)
z = torch.rand((10, 10)).round().to_sparse()
assert not torch.all(_equal(y, z))
def test_relation_family_01():
d = Data()
d.add_node_type('Whatever', 10)
fam = RelationFamily(d, 'Dummy-Dummy', 0, 0, True, DEDICOMDecoder)
with pytest.raises(ValueError):
fam.add_relation_type('Dummy-Dummy', None, None)
with pytest.raises(ValueError):
fam.add_relation_type('Dummy-Dummy', 'bad-value', None)
with pytest.raises(ValueError):
fam.add_relation_type('Dummy-Dummy', None, 'bad-value')
with pytest.raises(ValueError):
fam.add_relation_type('Dummy-Dummy', torch.rand((5, 5)), None)
with pytest.raises(ValueError):
fam.add_relation_type('Dummy-Dummy', None, torch.rand((5, 5)))
with pytest.raises(ValueError):
fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), torch.rand((10, 10)))
with pytest.raises(ValueError):
fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), None)
def test_relation_family_02():
d = Data()
d.add_node_type('A', 10)
d.add_node_type('B', 5)
fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder)
with pytest.raises(ValueError):
fam.add_relation_type('A-B', torch.rand((10, 5)).round(),
torch.rand((5, 10)).round())
def test_relation_family_03():
d = Data()
d.add_node_type('A', 10)
d.add_node_type('B', 5)
fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder)
fam.add_relation_type('A-B', torch.rand((10, 5)).round())
assert torch.all(fam.relation_types[0].adjacency_matrix.transpose(0, 1) == \
fam.relation_types[0].adjacency_matrix_backward)
def test_data_01():
d = Data()
d.add_node_type('Gene', 1000)
@@ -17,14 +86,18 @@ def test_data_01():
dummy_1 = torch.zeros((1000, 100))
dummy_2 = torch.zeros((100, 100))
dummy_3 = torch.zeros((1000, 1000))
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
fam.add_relation_type('Target', 1, 0, dummy_0)
fam.add_relation_type('Target', dummy_0)
fam = d.add_relation_family('Gene-Gene', 0, 0, True)
fam.add_relation_type('Interaction', 0, 0, dummy_3)
fam.add_relation_type('Interaction', dummy_3)
fam = d.add_relation_family('Drug-Drug', 1, 1, True)
fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_2)
fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_2)
fam.add_relation_type('Side Effect: Death', 1, 1, dummy_2)
fam.add_relation_type('Side Effect: Nausea', dummy_2)
fam.add_relation_type('Side Effect: Infertility', dummy_2)
fam.add_relation_type('Side Effect: Death', dummy_2)
print(d)
@@ -40,19 +113,19 @@ def test_data_02():
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
with pytest.raises(ValueError):
fam.add_relation_type('Target', 1, 0, dummy_1)
fam.add_relation_type('Target', dummy_1)
fam = d.add_relation_family('Gene-Gene', 0, 0, True)
with pytest.raises(ValueError):
fam.add_relation_type('Interaction', 0, 0, dummy_2)
fam.add_relation_type('Interaction', dummy_2)
fam = d.add_relation_family('Drug-Drug', 1, 1, True)
with pytest.raises(ValueError):
fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_3)
fam.add_relation_type('Side Effect: Nausea', dummy_3)
with pytest.raises(ValueError):
fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_3)
fam.add_relation_type('Side Effect: Infertility', dummy_3)
with pytest.raises(ValueError):
fam.add_relation_type('Side Effect: Death', 1, 1, dummy_3)
fam.add_relation_type('Side Effect: Death', dummy_3)
print(d)
@@ -62,5 +135,5 @@ def test_data_03():
d.add_node_type('Drug', 100)
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
with pytest.raises(ValueError):
fam.add_relation_type('Target', 1, 0, None)
fam.add_relation_type('Target', None)
print(d)

+ 1
- 1
tests/icosagon/test_declayer.py Wyświetl plik

@@ -22,7 +22,7 @@ def test_decode_layer_01():
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1', 0, 0,
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))


+ 0
- 1
tests/icosagon/test_decode.py Wyświetl plik

@@ -151,7 +151,6 @@ def test_is_inner_product_symmetric_01():
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)


+ 5
- 5
tests/icosagon/test_input.py Wyświetl plik

@@ -11,15 +11,15 @@ def _some_data():
d.add_node_type('Drug', 100)
fam = d.add_relation_family('Drug-Gene', 1, 0, False)
fam.add_relation_type('Target', 1, 0, torch.rand(100, 1000))
fam.add_relation_type('Target', torch.rand(100, 1000))
fam = d.add_relation_family('Gene-Gene', 0, 0, False)
fam.add_relation_type('Interaction', 0, 0, torch.rand(1000, 1000))
fam.add_relation_type('Interaction', torch.rand(1000, 1000))
fam = d.add_relation_family('Drug-Drug', 1, 1, False)
fam.add_relation_type('Side Effect: Nausea', 1, 1, torch.rand(100, 100))
fam.add_relation_type('Side Effect: Infertility', 1, 1, torch.rand(100, 100))
fam.add_relation_type('Side Effect: Death', 1, 1, torch.rand(100, 100))
fam.add_relation_type('Side Effect: Nausea', torch.rand(100, 100))
fam.add_relation_type('Side Effect: Infertility', torch.rand(100, 100))
fam.add_relation_type('Side Effect: Death', torch.rand(100, 100))
return d


Ładowanie…
Anuluj
Zapisz