@@ -80,22 +80,12 @@ class RelationFamily(RelationFamilyBase): | |||
self.relation_types = [] | |||
def add_relation_type(self, | |||
name: str, node_type_row: int, node_type_column: int, | |||
adjacency_matrix: torch.Tensor, | |||
name: str, adjacency_matrix: torch.Tensor, | |||
adjacency_matrix_backward: torch.Tensor = None) -> None: | |||
name = str(name) | |||
node_type_row = int(node_type_row) | |||
node_type_column = int(node_type_column) | |||
if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column): | |||
raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') | |||
if node_type_row < 0 or node_type_row >= len(self.data.node_types): | |||
raise ValueError('node_type_row outside of the valid range of node types') | |||
if node_type_column < 0 or node_type_column >= len(self.data.node_types): | |||
raise ValueError('node_type_column outside of the valid range of node types') | |||
node_type_row = self.node_type_row | |||
node_type_column = self.node_type_column | |||
if adjacency_matrix is None and adjacency_matrix_backward is None: | |||
raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') | |||
@@ -104,13 +94,6 @@ class RelationFamily(RelationFamilyBase): | |||
not isinstance(adjacency_matrix, torch.Tensor): | |||
raise ValueError('adjacency_matrix must be a torch.Tensor') | |||
# if isinstance(adjacency_matrix_backward, str) and \ | |||
# adjacency_matrix_backward == 'symmetric': | |||
# if self.is_symmetric: | |||
# adjacency_matrix_backward = None | |||
# else: | |||
# adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||
if adjacency_matrix_backward is not None \ | |||
and not isinstance(adjacency_matrix_backward, torch.Tensor): | |||
raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | |||
@@ -154,7 +137,7 @@ class RelationFamily(RelationFamilyBase): | |||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||
if (r.adjacency_matrix is not None \ | |||
and r.adjacency_matrix_backward is not None) \ | |||
or self.is_symmetric \ | |||
or self.node_type_row == self.node_type_column \ | |||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||
self.node_name(self.node_type_column))) | |||
@@ -167,7 +150,7 @@ class RelationFamily(RelationFamilyBase): | |||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||
if (r.adjacency_matrix is not None \ | |||
and r.adjacency_matrix_backward is not None) \ | |||
or self.is_symmetric \ | |||
or self.node_type_row == self.node_type_column \ | |||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||
self.node_name(self.node_type_column))) | |||
@@ -195,6 +178,17 @@ class Data(object): | |||
node_type_column: int, is_symmetric: bool, | |||
decoder_class: Type = DEDICOMDecoder): | |||
name = str(name) | |||
node_type_row = int(node_type_row) | |||
node_type_column = int(node_type_column) | |||
is_symmetric = bool(is_symmetric) | |||
if node_type_row < 0 or node_type_row >= len(self.node_types): | |||
raise ValueError('node_type_row outside of the valid range of node types') | |||
if node_type_column < 0 or node_type_column >= len(self.node_types): | |||
raise ValueError('node_type_column outside of the valid range of node types') | |||
fam = RelationFamily(self, name, node_type_row, node_type_column, | |||
is_symmetric, decoder_class) | |||
self.relation_families.append(fam) | |||
@@ -26,19 +26,19 @@ def _some_data_with_interactions(): | |||
d.add_node_type('Drug', 100) | |||
fam = d.add_relation_family('Drug-Gene', 1, 0, True) | |||
fam.add_relation_type('Target', 1, 0, | |||
fam.add_relation_type('Target', | |||
torch.rand((100, 1000), dtype=torch.float32).round()) | |||
fam = d.add_relation_family('Gene-Gene', 0, 0, True) | |||
fam.add_relation_type('Interaction', 0, 0, | |||
fam.add_relation_type('Interaction', | |||
_symmetric_random(1000, 1000)) | |||
fam = d.add_relation_family('Drug-Drug', 1, 1, True) | |||
fam.add_relation_type('Side Effect: Nausea', 1, 1, | |||
fam.add_relation_type('Side Effect: Nausea', | |||
_symmetric_random(100, 100)) | |||
fam.add_relation_type('Side Effect: Infertility', 1, 1, | |||
fam.add_relation_type('Side Effect: Infertility', | |||
_symmetric_random(100, 100)) | |||
fam.add_relation_type('Side Effect: Death', 1, 1, | |||
fam.add_relation_type('Side Effect: Death', | |||
_symmetric_random(100, 100)) | |||
return d | |||
@@ -97,7 +97,7 @@ def test_decagon_layer_04(): | |||
d = Data() | |||
d.add_node_type('Dummy', 100) | |||
fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) | |||
fam.add_relation_type('Dummy Relation', 0, 0, | |||
fam.add_relation_type('Dummy Relation', | |||
_symmetric_random(100, 100).to_sparse()) | |||
in_layer = OneHotInputLayer(d) | |||
@@ -139,9 +139,9 @@ def test_decagon_layer_05(): | |||
d = Data() | |||
d.add_node_type('Dummy', 100) | |||
fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) | |||
fam.add_relation_type('Dummy Relation 1', 0, 0, | |||
fam.add_relation_type('Dummy Relation 1', | |||
_symmetric_random(100, 100).to_sparse()) | |||
fam.add_relation_type('Dummy Relation 2', 0, 0, | |||
fam.add_relation_type('Dummy Relation 2', | |||
_symmetric_random(100, 100).to_sparse()) | |||
in_layer = OneHotInputLayer(d) | |||
@@ -4,11 +4,80 @@ | |||
# | |||
from icosagon import Data | |||
from icosagon.data import Data, \ | |||
_equal, \ | |||
RelationFamily | |||
from icosagon.decode import DEDICOMDecoder | |||
import torch | |||
import pytest | |||
def test_equal_01(): | |||
x = torch.rand((10, 10)) | |||
y = torch.rand((10, 10)).round().to_sparse() | |||
assert torch.all(_equal(x, x)) | |||
assert torch.all(_equal(y, y)) | |||
with pytest.raises(ValueError): | |||
_equal(x, y) | |||
z = torch.rand((10, 10)).round().to_sparse() | |||
assert not torch.all(_equal(y, z)) | |||
def test_relation_family_01(): | |||
d = Data() | |||
d.add_node_type('Whatever', 10) | |||
fam = RelationFamily(d, 'Dummy-Dummy', 0, 0, True, DEDICOMDecoder) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Dummy-Dummy', None, None) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Dummy-Dummy', 'bad-value', None) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Dummy-Dummy', None, 'bad-value') | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Dummy-Dummy', torch.rand((5, 5)), None) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Dummy-Dummy', None, torch.rand((5, 5))) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), torch.rand((10, 10))) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), None) | |||
def test_relation_family_02(): | |||
d = Data() | |||
d.add_node_type('A', 10) | |||
d.add_node_type('B', 5) | |||
fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('A-B', torch.rand((10, 5)).round(), | |||
torch.rand((5, 10)).round()) | |||
def test_relation_family_03(): | |||
d = Data() | |||
d.add_node_type('A', 10) | |||
d.add_node_type('B', 5) | |||
fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder) | |||
fam.add_relation_type('A-B', torch.rand((10, 5)).round()) | |||
assert torch.all(fam.relation_types[0].adjacency_matrix.transpose(0, 1) == \ | |||
fam.relation_types[0].adjacency_matrix_backward) | |||
def test_data_01(): | |||
d = Data() | |||
d.add_node_type('Gene', 1000) | |||
@@ -17,14 +86,18 @@ def test_data_01(): | |||
dummy_1 = torch.zeros((1000, 100)) | |||
dummy_2 = torch.zeros((100, 100)) | |||
dummy_3 = torch.zeros((1000, 1000)) | |||
fam = d.add_relation_family('Drug-Gene', 1, 0, True) | |||
fam.add_relation_type('Target', 1, 0, dummy_0) | |||
fam.add_relation_type('Target', dummy_0) | |||
fam = d.add_relation_family('Gene-Gene', 0, 0, True) | |||
fam.add_relation_type('Interaction', 0, 0, dummy_3) | |||
fam.add_relation_type('Interaction', dummy_3) | |||
fam = d.add_relation_family('Drug-Drug', 1, 1, True) | |||
fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_2) | |||
fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_2) | |||
fam.add_relation_type('Side Effect: Death', 1, 1, dummy_2) | |||
fam.add_relation_type('Side Effect: Nausea', dummy_2) | |||
fam.add_relation_type('Side Effect: Infertility', dummy_2) | |||
fam.add_relation_type('Side Effect: Death', dummy_2) | |||
print(d) | |||
@@ -40,19 +113,19 @@ def test_data_02(): | |||
fam = d.add_relation_family('Drug-Gene', 1, 0, True) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Target', 1, 0, dummy_1) | |||
fam.add_relation_type('Target', dummy_1) | |||
fam = d.add_relation_family('Gene-Gene', 0, 0, True) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Interaction', 0, 0, dummy_2) | |||
fam.add_relation_type('Interaction', dummy_2) | |||
fam = d.add_relation_family('Drug-Drug', 1, 1, True) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_3) | |||
fam.add_relation_type('Side Effect: Nausea', dummy_3) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_3) | |||
fam.add_relation_type('Side Effect: Infertility', dummy_3) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Side Effect: Death', 1, 1, dummy_3) | |||
fam.add_relation_type('Side Effect: Death', dummy_3) | |||
print(d) | |||
@@ -62,5 +135,5 @@ def test_data_03(): | |||
d.add_node_type('Drug', 100) | |||
fam = d.add_relation_family('Drug-Gene', 1, 0, True) | |||
with pytest.raises(ValueError): | |||
fam.add_relation_type('Target', 1, 0, None) | |||
fam.add_relation_type('Target', None) | |||
print(d) |
@@ -22,7 +22,7 @@ def test_decode_layer_01(): | |||
d.add_node_type('Dummy', 100) | |||
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||
fam.add_relation_type('Dummy Relation 1', 0, 0, | |||
fam.add_relation_type('Dummy Relation 1', | |||
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | |||
prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | |||
@@ -151,7 +151,6 @@ def test_is_inner_product_symmetric_01(): | |||
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] | |||
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] | |||
assert isinstance(res_1, list) | |||
assert isinstance(res_2, list) | |||
@@ -11,15 +11,15 @@ def _some_data(): | |||
d.add_node_type('Drug', 100) | |||
fam = d.add_relation_family('Drug-Gene', 1, 0, False) | |||
fam.add_relation_type('Target', 1, 0, torch.rand(100, 1000)) | |||
fam.add_relation_type('Target', torch.rand(100, 1000)) | |||
fam = d.add_relation_family('Gene-Gene', 0, 0, False) | |||
fam.add_relation_type('Interaction', 0, 0, torch.rand(1000, 1000)) | |||
fam.add_relation_type('Interaction', torch.rand(1000, 1000)) | |||
fam = d.add_relation_family('Drug-Drug', 1, 1, False) | |||
fam.add_relation_type('Side Effect: Nausea', 1, 1, torch.rand(100, 100)) | |||
fam.add_relation_type('Side Effect: Infertility', 1, 1, torch.rand(100, 100)) | |||
fam.add_relation_type('Side Effect: Death', 1, 1, torch.rand(100, 100)) | |||
fam.add_relation_type('Side Effect: Nausea', torch.rand(100, 100)) | |||
fam.add_relation_type('Side Effect: Infertility', torch.rand(100, 100)) | |||
fam.add_relation_type('Side Effect: Death', torch.rand(100, 100)) | |||
return d | |||