@@ -80,22 +80,12 @@ class RelationFamily(RelationFamilyBase): | |||||
self.relation_types = [] | self.relation_types = [] | ||||
def add_relation_type(self, | def add_relation_type(self, | ||||
name: str, node_type_row: int, node_type_column: int, | |||||
adjacency_matrix: torch.Tensor, | |||||
name: str, adjacency_matrix: torch.Tensor, | |||||
adjacency_matrix_backward: torch.Tensor = None) -> None: | adjacency_matrix_backward: torch.Tensor = None) -> None: | ||||
name = str(name) | name = str(name) | ||||
node_type_row = int(node_type_row) | |||||
node_type_column = int(node_type_column) | |||||
if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column): | |||||
raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') | |||||
if node_type_row < 0 or node_type_row >= len(self.data.node_types): | |||||
raise ValueError('node_type_row outside of the valid range of node types') | |||||
if node_type_column < 0 or node_type_column >= len(self.data.node_types): | |||||
raise ValueError('node_type_column outside of the valid range of node types') | |||||
node_type_row = self.node_type_row | |||||
node_type_column = self.node_type_column | |||||
if adjacency_matrix is None and adjacency_matrix_backward is None: | if adjacency_matrix is None and adjacency_matrix_backward is None: | ||||
raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') | raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') | ||||
@@ -104,13 +94,6 @@ class RelationFamily(RelationFamilyBase): | |||||
not isinstance(adjacency_matrix, torch.Tensor): | not isinstance(adjacency_matrix, torch.Tensor): | ||||
raise ValueError('adjacency_matrix must be a torch.Tensor') | raise ValueError('adjacency_matrix must be a torch.Tensor') | ||||
# if isinstance(adjacency_matrix_backward, str) and \ | |||||
# adjacency_matrix_backward == 'symmetric': | |||||
# if self.is_symmetric: | |||||
# adjacency_matrix_backward = None | |||||
# else: | |||||
# adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||||
if adjacency_matrix_backward is not None \ | if adjacency_matrix_backward is not None \ | ||||
and not isinstance(adjacency_matrix_backward, torch.Tensor): | and not isinstance(adjacency_matrix_backward, torch.Tensor): | ||||
raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | ||||
@@ -154,7 +137,7 @@ class RelationFamily(RelationFamilyBase): | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | s += '\n - %s%s' % (r.name, ' (two-way)' \ | ||||
if (r.adjacency_matrix is not None \ | if (r.adjacency_matrix is not None \ | ||||
and r.adjacency_matrix_backward is not None) \ | and r.adjacency_matrix_backward is not None) \ | ||||
or self.is_symmetric \ | |||||
or self.node_type_row == self.node_type_column \ | |||||
else '%s <- %s' % (self.node_name(self.node_type_row), | else '%s <- %s' % (self.node_name(self.node_type_row), | ||||
self.node_name(self.node_type_column))) | self.node_name(self.node_type_column))) | ||||
@@ -167,7 +150,7 @@ class RelationFamily(RelationFamilyBase): | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | s += '\n - %s%s' % (r.name, ' (two-way)' \ | ||||
if (r.adjacency_matrix is not None \ | if (r.adjacency_matrix is not None \ | ||||
and r.adjacency_matrix_backward is not None) \ | and r.adjacency_matrix_backward is not None) \ | ||||
or self.is_symmetric \ | |||||
or self.node_type_row == self.node_type_column \ | |||||
else '%s <- %s' % (self.node_name(self.node_type_row), | else '%s <- %s' % (self.node_name(self.node_type_row), | ||||
self.node_name(self.node_type_column))) | self.node_name(self.node_type_column))) | ||||
@@ -195,6 +178,17 @@ class Data(object): | |||||
node_type_column: int, is_symmetric: bool, | node_type_column: int, is_symmetric: bool, | ||||
decoder_class: Type = DEDICOMDecoder): | decoder_class: Type = DEDICOMDecoder): | ||||
name = str(name) | |||||
node_type_row = int(node_type_row) | |||||
node_type_column = int(node_type_column) | |||||
is_symmetric = bool(is_symmetric) | |||||
if node_type_row < 0 or node_type_row >= len(self.node_types): | |||||
raise ValueError('node_type_row outside of the valid range of node types') | |||||
if node_type_column < 0 or node_type_column >= len(self.node_types): | |||||
raise ValueError('node_type_column outside of the valid range of node types') | |||||
fam = RelationFamily(self, name, node_type_row, node_type_column, | fam = RelationFamily(self, name, node_type_row, node_type_column, | ||||
is_symmetric, decoder_class) | is_symmetric, decoder_class) | ||||
self.relation_families.append(fam) | self.relation_families.append(fam) | ||||
@@ -26,19 +26,19 @@ def _some_data_with_interactions(): | |||||
d.add_node_type('Drug', 100) | d.add_node_type('Drug', 100) | ||||
fam = d.add_relation_family('Drug-Gene', 1, 0, True) | fam = d.add_relation_family('Drug-Gene', 1, 0, True) | ||||
fam.add_relation_type('Target', 1, 0, | |||||
fam.add_relation_type('Target', | |||||
torch.rand((100, 1000), dtype=torch.float32).round()) | torch.rand((100, 1000), dtype=torch.float32).round()) | ||||
fam = d.add_relation_family('Gene-Gene', 0, 0, True) | fam = d.add_relation_family('Gene-Gene', 0, 0, True) | ||||
fam.add_relation_type('Interaction', 0, 0, | |||||
fam.add_relation_type('Interaction', | |||||
_symmetric_random(1000, 1000)) | _symmetric_random(1000, 1000)) | ||||
fam = d.add_relation_family('Drug-Drug', 1, 1, True) | fam = d.add_relation_family('Drug-Drug', 1, 1, True) | ||||
fam.add_relation_type('Side Effect: Nausea', 1, 1, | |||||
fam.add_relation_type('Side Effect: Nausea', | |||||
_symmetric_random(100, 100)) | _symmetric_random(100, 100)) | ||||
fam.add_relation_type('Side Effect: Infertility', 1, 1, | |||||
fam.add_relation_type('Side Effect: Infertility', | |||||
_symmetric_random(100, 100)) | _symmetric_random(100, 100)) | ||||
fam.add_relation_type('Side Effect: Death', 1, 1, | |||||
fam.add_relation_type('Side Effect: Death', | |||||
_symmetric_random(100, 100)) | _symmetric_random(100, 100)) | ||||
return d | return d | ||||
@@ -97,7 +97,7 @@ def test_decagon_layer_04(): | |||||
d = Data() | d = Data() | ||||
d.add_node_type('Dummy', 100) | d.add_node_type('Dummy', 100) | ||||
fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) | fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) | ||||
fam.add_relation_type('Dummy Relation', 0, 0, | |||||
fam.add_relation_type('Dummy Relation', | |||||
_symmetric_random(100, 100).to_sparse()) | _symmetric_random(100, 100).to_sparse()) | ||||
in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
@@ -139,9 +139,9 @@ def test_decagon_layer_05(): | |||||
d = Data() | d = Data() | ||||
d.add_node_type('Dummy', 100) | d.add_node_type('Dummy', 100) | ||||
fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) | fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) | ||||
fam.add_relation_type('Dummy Relation 1', 0, 0, | |||||
fam.add_relation_type('Dummy Relation 1', | |||||
_symmetric_random(100, 100).to_sparse()) | _symmetric_random(100, 100).to_sparse()) | ||||
fam.add_relation_type('Dummy Relation 2', 0, 0, | |||||
fam.add_relation_type('Dummy Relation 2', | |||||
_symmetric_random(100, 100).to_sparse()) | _symmetric_random(100, 100).to_sparse()) | ||||
in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
@@ -4,11 +4,80 @@ | |||||
# | # | ||||
from icosagon import Data | |||||
from icosagon.data import Data, \ | |||||
_equal, \ | |||||
RelationFamily | |||||
from icosagon.decode import DEDICOMDecoder | |||||
import torch | import torch | ||||
import pytest | import pytest | ||||
def test_equal_01(): | |||||
x = torch.rand((10, 10)) | |||||
y = torch.rand((10, 10)).round().to_sparse() | |||||
assert torch.all(_equal(x, x)) | |||||
assert torch.all(_equal(y, y)) | |||||
with pytest.raises(ValueError): | |||||
_equal(x, y) | |||||
z = torch.rand((10, 10)).round().to_sparse() | |||||
assert not torch.all(_equal(y, z)) | |||||
def test_relation_family_01(): | |||||
d = Data() | |||||
d.add_node_type('Whatever', 10) | |||||
fam = RelationFamily(d, 'Dummy-Dummy', 0, 0, True, DEDICOMDecoder) | |||||
with pytest.raises(ValueError): | |||||
fam.add_relation_type('Dummy-Dummy', None, None) | |||||
with pytest.raises(ValueError): | |||||
fam.add_relation_type('Dummy-Dummy', 'bad-value', None) | |||||
with pytest.raises(ValueError): | |||||
fam.add_relation_type('Dummy-Dummy', None, 'bad-value') | |||||
with pytest.raises(ValueError): | |||||
fam.add_relation_type('Dummy-Dummy', torch.rand((5, 5)), None) | |||||
with pytest.raises(ValueError): | |||||
fam.add_relation_type('Dummy-Dummy', None, torch.rand((5, 5))) | |||||
with pytest.raises(ValueError): | |||||
fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), torch.rand((10, 10))) | |||||
with pytest.raises(ValueError): | |||||
fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), None) | |||||
def test_relation_family_02(): | |||||
d = Data() | |||||
d.add_node_type('A', 10) | |||||
d.add_node_type('B', 5) | |||||
fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder) | |||||
with pytest.raises(ValueError): | |||||
fam.add_relation_type('A-B', torch.rand((10, 5)).round(), | |||||
torch.rand((5, 10)).round()) | |||||
def test_relation_family_03(): | |||||
d = Data() | |||||
d.add_node_type('A', 10) | |||||
d.add_node_type('B', 5) | |||||
fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder) | |||||
fam.add_relation_type('A-B', torch.rand((10, 5)).round()) | |||||
assert torch.all(fam.relation_types[0].adjacency_matrix.transpose(0, 1) == \ | |||||
fam.relation_types[0].adjacency_matrix_backward) | |||||
def test_data_01(): | def test_data_01(): | ||||
d = Data() | d = Data() | ||||
d.add_node_type('Gene', 1000) | d.add_node_type('Gene', 1000) | ||||
@@ -17,14 +86,18 @@ def test_data_01(): | |||||
dummy_1 = torch.zeros((1000, 100)) | dummy_1 = torch.zeros((1000, 100)) | ||||
dummy_2 = torch.zeros((100, 100)) | dummy_2 = torch.zeros((100, 100)) | ||||
dummy_3 = torch.zeros((1000, 1000)) | dummy_3 = torch.zeros((1000, 1000)) | ||||
fam = d.add_relation_family('Drug-Gene', 1, 0, True) | fam = d.add_relation_family('Drug-Gene', 1, 0, True) | ||||
fam.add_relation_type('Target', 1, 0, dummy_0) | |||||
fam.add_relation_type('Target', dummy_0) | |||||
fam = d.add_relation_family('Gene-Gene', 0, 0, True) | fam = d.add_relation_family('Gene-Gene', 0, 0, True) | ||||
fam.add_relation_type('Interaction', 0, 0, dummy_3) | |||||
fam.add_relation_type('Interaction', dummy_3) | |||||
fam = d.add_relation_family('Drug-Drug', 1, 1, True) | fam = d.add_relation_family('Drug-Drug', 1, 1, True) | ||||
fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_2) | |||||
fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_2) | |||||
fam.add_relation_type('Side Effect: Death', 1, 1, dummy_2) | |||||
fam.add_relation_type('Side Effect: Nausea', dummy_2) | |||||
fam.add_relation_type('Side Effect: Infertility', dummy_2) | |||||
fam.add_relation_type('Side Effect: Death', dummy_2) | |||||
print(d) | print(d) | ||||
@@ -40,19 +113,19 @@ def test_data_02(): | |||||
fam = d.add_relation_family('Drug-Gene', 1, 0, True) | fam = d.add_relation_family('Drug-Gene', 1, 0, True) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
fam.add_relation_type('Target', 1, 0, dummy_1) | |||||
fam.add_relation_type('Target', dummy_1) | |||||
fam = d.add_relation_family('Gene-Gene', 0, 0, True) | fam = d.add_relation_family('Gene-Gene', 0, 0, True) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
fam.add_relation_type('Interaction', 0, 0, dummy_2) | |||||
fam.add_relation_type('Interaction', dummy_2) | |||||
fam = d.add_relation_family('Drug-Drug', 1, 1, True) | fam = d.add_relation_family('Drug-Drug', 1, 1, True) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_3) | |||||
fam.add_relation_type('Side Effect: Nausea', dummy_3) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_3) | |||||
fam.add_relation_type('Side Effect: Infertility', dummy_3) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
fam.add_relation_type('Side Effect: Death', 1, 1, dummy_3) | |||||
fam.add_relation_type('Side Effect: Death', dummy_3) | |||||
print(d) | print(d) | ||||
@@ -62,5 +135,5 @@ def test_data_03(): | |||||
d.add_node_type('Drug', 100) | d.add_node_type('Drug', 100) | ||||
fam = d.add_relation_family('Drug-Gene', 1, 0, True) | fam = d.add_relation_family('Drug-Gene', 1, 0, True) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
fam.add_relation_type('Target', 1, 0, None) | |||||
fam.add_relation_type('Target', None) | |||||
print(d) | print(d) |
@@ -22,7 +22,7 @@ def test_decode_layer_01(): | |||||
d.add_node_type('Dummy', 100) | d.add_node_type('Dummy', 100) | ||||
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | ||||
fam.add_relation_type('Dummy Relation 1', 0, 0, | |||||
fam.add_relation_type('Dummy Relation 1', | |||||
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | ||||
prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | ||||
@@ -151,7 +151,6 @@ def test_is_inner_product_symmetric_01(): | |||||
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] | res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] | ||||
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] | res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] | ||||
assert isinstance(res_1, list) | assert isinstance(res_1, list) | ||||
assert isinstance(res_2, list) | assert isinstance(res_2, list) | ||||
@@ -11,15 +11,15 @@ def _some_data(): | |||||
d.add_node_type('Drug', 100) | d.add_node_type('Drug', 100) | ||||
fam = d.add_relation_family('Drug-Gene', 1, 0, False) | fam = d.add_relation_family('Drug-Gene', 1, 0, False) | ||||
fam.add_relation_type('Target', 1, 0, torch.rand(100, 1000)) | |||||
fam.add_relation_type('Target', torch.rand(100, 1000)) | |||||
fam = d.add_relation_family('Gene-Gene', 0, 0, False) | fam = d.add_relation_family('Gene-Gene', 0, 0, False) | ||||
fam.add_relation_type('Interaction', 0, 0, torch.rand(1000, 1000)) | |||||
fam.add_relation_type('Interaction', torch.rand(1000, 1000)) | |||||
fam = d.add_relation_family('Drug-Drug', 1, 1, False) | fam = d.add_relation_family('Drug-Drug', 1, 1, False) | ||||
fam.add_relation_type('Side Effect: Nausea', 1, 1, torch.rand(100, 100)) | |||||
fam.add_relation_type('Side Effect: Infertility', 1, 1, torch.rand(100, 100)) | |||||
fam.add_relation_type('Side Effect: Death', 1, 1, torch.rand(100, 100)) | |||||
fam.add_relation_type('Side Effect: Nausea', torch.rand(100, 100)) | |||||
fam.add_relation_type('Side Effect: Infertility', torch.rand(100, 100)) | |||||
fam.add_relation_type('Side Effect: Death', torch.rand(100, 100)) | |||||
return d | return d | ||||