From b5d2f8fcda9288394504427477eb6514310c73a0 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 9 Jun 2020 15:23:43 +0200 Subject: [PATCH] Start implementing RelationFamily. --- src/icosagon/data.py | 161 +++++++++++++++++++++++++----------- tests/icosagon/test_data.py | 33 +++++--- 2 files changed, 133 insertions(+), 61 deletions(-) diff --git a/src/icosagon/data.py b/src/icosagon/data.py index 86b17d1..cd55b29 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -10,7 +10,10 @@ import torch from typing import List, \ Dict, \ Tuple, \ - Any + Any, \ + Type +from .decode import DEDICOMDecoder, \ + BilinearDecoder @dataclass @@ -25,25 +28,33 @@ class RelationType(object): node_type_row: int node_type_column: int adjacency_matrix: torch.Tensor + two_way: bool hints: Dict[str, Any] = field(default_factory=dict) -class Data(object): - node_types: List[NodeType] - relation_types: Dict[Tuple[int, int], List[RelationType]] +class RelationFamily(object): + def __init__(self, + data: 'Data', + name: str, + node_type_row: int, + node_type_column: int, + is_symmetric: bool, + decoder_class: Type) -> None: - def __init__(self) -> None: - self.node_types = [] - self.relation_types = defaultdict(list) + if not is_symmetric and \ + decoder_class != DEDICOMDecoder and \ + decoder_class != BilinearDecoder: + raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') - def add_node_type(self, name: str, count: int) -> None: - name = str(name) - count = int(count) - if not name: - raise ValueError('You must provide a non-empty node type name') - if count <= 0: - raise ValueError('You must provide a positive node count') - self.node_types.append(NodeType(name, count)) + self.data = data + self.name = name + self.node_type_row = node_type_row + self.node_type_column = node_type_column + self.is_symmetric = is_symmetric + self.decoder_class = decoder_class + + self.relation_types = { (node_type_row, node_type_column): [], + (node_type_column, node_type_row): [] } def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, @@ -53,45 +64,110 @@ class Data(object): node_type_row = int(node_type_row) node_type_column = int(node_type_column) - if node_type_row < 0 or node_type_row > len(self.node_types): + if (node_type_row, node_type_column) not in self.relation_types: + raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') + + if node_type_row < 0 or node_type_row >= len(self.data.node_types): raise ValueError('node_type_row outside of the valid range of node types') - if node_type_column < 0 or node_type_column > len(self.node_types): + if node_type_column < 0 or node_type_column >= len(self.data.node_types): raise ValueError('node_type_column outside of the valid range of node types') if not isinstance(adjacency_matrix, torch.Tensor): raise ValueError('adjacency_matrix must be a torch.Tensor') - if adjacency_matrix_backward and not isinstance(adjacency_matrix_backward, torch.Tensor): + if adjacency_matrix_backward is not None \ + and not isinstance(adjacency_matrix_backward, torch.Tensor): raise ValueError('adjacency_matrix_backward must be a torch.Tensor') - if adjacency_matrix.shape != (self.node_types[node_type_row].count, - self.node_types[node_type_column].count): + if adjacency_matrix.shape != (self.data.node_types[node_type_row].count, + self.data.node_types[node_type_column].count): raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') if adjacency_matrix_backward is not None and \ - adjacency_matrix_backward.shape != (self.node_types[node_type_column].count, - self.node_types[node_type_row].count): + adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, + self.data.node_types[node_type_row].count): raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)') - two_way = bool(two_way) - if node_type_row == node_type_column and \ adjacency_matrix_backward is not None: raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') - self.relation_types[node_type_row, node_type_column].append( - RelationType(name, node_type_row, node_type_column, - adjacency_matrix)) + if self.is_symmetric and adjacency_matrix_backward is not None: + raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family') + + if self.is_symmetric and node_type_row == node_type_column and \ + not torch.all(adjacency_matrix == adjacency_matrix.transpose(0, 1)): + raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') + + two_way = bool(two_way) if node_type_row != node_type_column and two_way: - hints = { 'display': False } + print('%d != %d' % (node_type_row, node_type_column)) if adjacency_matrix_backward is None: adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) - hints['symmetric'] = True self.relation_types[node_type_column, node_type_row].append( RelationType(name, node_type_column, node_type_row, - adjacency_matrix_backward, hints)) + adjacency_matrix_backward, two_way, { 'display': False })) + + self.relation_types[node_type_row, node_type_column].append( + RelationType(name, node_type_row, node_type_column, + adjacency_matrix, two_way)) + + def node_name(self, index): + return self.data.node_types[index].name + + def __repr__(self): + s = 'Relation family %s' % self.name + + for (node_type_row, node_type_column), rels in self.relation_types.items(): + for r in rels: + if 'display' in r.hints and not r.hints['display']: + continue + s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ + (self.node_name(node_type_row), self.node_name(node_type_column))) + + return s + + def repr_indented(self): + s = ' - %s' % self.name + + for (node_type_row, node_type_column), rels in self.relation_types.items(): + for r in rels: + if 'display' in r.hints and not r.hints['display']: + continue + s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ + (self.node_name(node_type_row), self.node_name(node_type_column))) + + return s + + +class Data(object): + node_types: List[NodeType] + relation_types: Dict[Tuple[int, int], List[RelationType]] + + def __init__(self) -> None: + self.node_types = [] + self.relation_families = [] + + def add_node_type(self, name: str, count: int) -> None: + name = str(name) + count = int(count) + if not name: + raise ValueError('You must provide a non-empty node type name') + if count <= 0: + raise ValueError('You must provide a positive node count') + self.node_types.append(NodeType(name, count)) + + def add_relation_family(self, name: str, node_type_row: int, + node_type_column: int, is_symmetric: bool, + decoder_class: Type = DEDICOMDecoder): + + fam = RelationFamily(self, name, node_type_row, node_type_column, + is_symmetric, decoder_class) + self.relation_families.append(fam) + + return fam def __repr__(self): n = len(self.node_types) @@ -102,27 +178,12 @@ class Data(object): s += '- ' + str(n) + ' node type(s):\n' for nt in self.node_types: s += ' - ' + nt.name + '\n' - if len(self.relation_types) == 0: - s += '- No relation types\n' + if len(self.relation_families) == 0: + s += '- No relation families\n' return s.strip() - s_1 = '' - count = 0 - for node_type_row in range(n): - for node_type_column in range(n): - if (node_type_row, node_type_column) not in self.relation_types: - continue - - s_1 += ' - ' + self.node_types[node_type_row].name + ' -- ' + \ - self.node_types[node_type_column].name + ':\n' - - for r in self.relation_types[node_type_row, node_type_column]: - if not r.hints.get('display', True): - continue - s_1 += ' - ' + r.name + '\n' - count += 1 - - s += '- %d relation type(s):\n' % count - s += s_1 + s += '- %d relation families:\n' % len(self.relation_families) + for fam in self.relation_families: + s += fam.repr_indented() + '\n' return s.strip() diff --git a/tests/icosagon/test_data.py b/tests/icosagon/test_data.py index 4ec8164..d67c1c1 100644 --- a/tests/icosagon/test_data.py +++ b/tests/icosagon/test_data.py @@ -17,11 +17,14 @@ def test_data_01(): dummy_1 = torch.zeros((1000, 100)) dummy_2 = torch.zeros((100, 100)) dummy_3 = torch.zeros((1000, 1000)) - d.add_relation_type('Target', 1, 0, dummy_0) - d.add_relation_type('Interaction', 0, 0, dummy_3) - d.add_relation_type('Side Effect: Nausea', 1, 1, dummy_2) - d.add_relation_type('Side Effect: Infertility', 1, 1, dummy_2) - d.add_relation_type('Side Effect: Death', 1, 1, dummy_2) + fam = d.add_relation_family('Drug-Gene', 1, 0, True) + fam.add_relation_type('Target', 1, 0, dummy_0) + fam = d.add_relation_family('Gene-Gene', 0, 0, True) + fam.add_relation_type('Interaction', 0, 0, dummy_3) + fam = d.add_relation_family('Drug-Drug', 1, 1, True) + fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_2) + fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_2) + fam.add_relation_type('Side Effect: Death', 1, 1, dummy_2) print(d) @@ -29,20 +32,27 @@ def test_data_02(): d = Data() d.add_node_type('Gene', 1000) d.add_node_type('Drug', 100) + dummy_0 = torch.zeros((100, 1000)) dummy_1 = torch.zeros((1000, 100)) dummy_2 = torch.zeros((100, 100)) dummy_3 = torch.zeros((1000, 1000)) + + fam = d.add_relation_family('Drug-Gene', 1, 0, True) with pytest.raises(ValueError): - d.add_relation_type('Target', 1, 0, dummy_1) + fam.add_relation_type('Target', 1, 0, dummy_1) + + fam = d.add_relation_family('Gene-Gene', 0, 0, True) with pytest.raises(ValueError): - d.add_relation_type('Interaction', 0, 0, dummy_2) + fam.add_relation_type('Interaction', 0, 0, dummy_2) + + fam = d.add_relation_family('Drug-Drug', 1, 1, True) with pytest.raises(ValueError): - d.add_relation_type('Side Effect: Nausea', 1, 1, dummy_3) + fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_3) with pytest.raises(ValueError): - d.add_relation_type('Side Effect: Infertility', 1, 1, dummy_3) + fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_3) with pytest.raises(ValueError): - d.add_relation_type('Side Effect: Death', 1, 1, dummy_3) + fam.add_relation_type('Side Effect: Death', 1, 1, dummy_3) print(d) @@ -50,6 +60,7 @@ def test_data_03(): d = Data() d.add_node_type('Gene', 1000) d.add_node_type('Drug', 100) + fam = d.add_relation_family('Drug-Gene', 1, 0, True) with pytest.raises(ValueError): - d.add_relation_type('Target', 1, 0, None) + fam.add_relation_type('Target', 1, 0, None) print(d)