From 3795545674daa2538657600ee32693e029f91ad8 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Sat, 6 Jun 2020 11:52:44 +0200 Subject: [PATCH] Start icosagon. --- requirements.txt | 3 + src/icosagon/__init__.py | 1 + src/icosagon/data.py | 127 ++++++++++++++++++++++++++++++++++++ tests/icosagon/test_data.py | 49 ++++++++++++++ 4 files changed, 180 insertions(+) create mode 100644 requirements.txt create mode 100644 src/icosagon/__init__.py create mode 100644 src/icosagon/data.py create mode 100644 tests/icosagon/test_data.py diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..89425dd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +numpy +torch +dataclasses diff --git a/src/icosagon/__init__.py b/src/icosagon/__init__.py new file mode 100644 index 0000000..dc4d081 --- /dev/null +++ b/src/icosagon/__init__.py @@ -0,0 +1 @@ +from .data import Data diff --git a/src/icosagon/data.py b/src/icosagon/data.py new file mode 100644 index 0000000..9f696ca --- /dev/null +++ b/src/icosagon/data.py @@ -0,0 +1,127 @@ +from collections import defaultdict +from dataclasses import dataclass +import torch + + +@dataclass +class NodeType(object): + name: str + count: int + + +@dataclass +class RelationType(object): + name: str + node_type_row: int + node_type_column: int + adjacency_matrix: torch.Tensor + is_autogenerated: bool + + +class Data(object): + def __init__(self) -> None: + self.node_types = [] + self.relation_types = defaultdict(lambda: defaultdict(list)) + + def add_node_type(self, name: str, count: int) -> None: + name = str(name) + count = int(count) + if not name: + raise ValueError('You must provide a non-empty node type name') + if count <= 0: + raise ValueError('You must provide a positive node count') + self.node_types.append(NodeType(name, count)) + + def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, + adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, + two_way: bool = True) -> None: + + name = str(name) + node_type_row = int(node_type_row) + node_type_column = int(node_type_column) + + if node_type_row < 0 or node_type_row > len(self.node_types): + raise ValueError('node_type_row outside of the valid range of node types') + + if node_type_column < 0 or node_type_column > len(self.node_types): + raise ValueError('node_type_column outside of the valid range of node types') + + if not isinstance(adjacency_matrix, torch.Tensor): + raise ValueError('adjacency_matrix must be a torch.Tensor') + + if adjacency_matrix_backward and not isinstance(adjacency_matrix_backward, torch.Tensor): + raise ValueError('adjacency_matrix_backward must be a torch.Tensor') + + if adjacency_matrix.shape != (self.node_types[node_type_row].count, + self.node_types[node_type_column].count): + raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') + + if adjacency_matrix_backward is not None and \ + adjacency_matrix_backward.shape != (self.node_types[node_type_column].count, + self.node_types[node_type_row].count): + raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)') + + two_way = bool(two_way) + + if node_type_row == node_type_column and \ + adjacency_matrix_backward is not None: + raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') + + self.relation_types[node_type_row][node_type_column].append( + RelationType(name, node_type_row, node_type_column, + adjacency_matrix, False)) + + if node_type_row != node_type_column and two_way: + if adjacency_matrix_backward is None: + adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) + self.relation_types[node_type_column][node_type_row].append( + RelationType(name, node_type_column, node_type_row, + adjacency_matrix_backward, True)) + + def __repr__(self): + n = len(self.node_types) + if n == 0: + return 'Empty Icosagon Data' + s = '' + s += 'Icosagon Data with:\n' + s += '- ' + str(n) + ' node type(s):\n' + for nt in self.node_types: + s += ' - ' + nt.name + '\n' + if len(self.relation_types) == 0: + s += '- No relation types\n' + return s.strip() + + s_1 = '' + count = 0 + for i in range(n): + for j in range(n): + if i not in self.relation_types or \ + j not in self.relation_types[i]: + continue + + s_1 += ' - ' + self.node_types[i].name + ' -- ' + \ + self.node_types[j].name + ':\n' + + for r in self.relation_types[i][j]: + if r.is_autogenerated: + continue + s_1 += ' - ' + r.name + '\n' + count += 1 + + s += '- %d relation type(s):\n' % count + s += s_1 + + return s.strip() + + # n = sum(map(len, self.relation_types)) + # + # for i in range(n): + # for j in range(n): + # key = (i, j) + # if key not in self.relation_types: + # continue + # rels = self.relation_types[key] + # + # for r in rels: + # + # return s.strip() diff --git a/tests/icosagon/test_data.py b/tests/icosagon/test_data.py new file mode 100644 index 0000000..ef64e7b --- /dev/null +++ b/tests/icosagon/test_data.py @@ -0,0 +1,49 @@ +from icosagon import Data +import torch +import pytest + + +def test_data_01(): + d = Data() + d.add_node_type('Gene', 1000) + d.add_node_type('Drug', 100) + dummy_0 = torch.zeros((100, 1000)) + dummy_1 = torch.zeros((1000, 100)) + dummy_2 = torch.zeros((100, 100)) + dummy_3 = torch.zeros((1000, 1000)) + d.add_relation_type('Target', 1, 0, dummy_0) + d.add_relation_type('Interaction', 0, 0, dummy_3) + d.add_relation_type('Side Effect: Nausea', 1, 1, dummy_2) + d.add_relation_type('Side Effect: Infertility', 1, 1, dummy_2) + d.add_relation_type('Side Effect: Death', 1, 1, dummy_2) + print(d) + + +def test_data_02(): + d = Data() + d.add_node_type('Gene', 1000) + d.add_node_type('Drug', 100) + dummy_0 = torch.zeros((100, 1000)) + dummy_1 = torch.zeros((1000, 100)) + dummy_2 = torch.zeros((100, 100)) + dummy_3 = torch.zeros((1000, 1000)) + with pytest.raises(ValueError): + d.add_relation_type('Target', 1, 0, dummy_1) + with pytest.raises(ValueError): + d.add_relation_type('Interaction', 0, 0, dummy_2) + with pytest.raises(ValueError): + d.add_relation_type('Side Effect: Nausea', 1, 1, dummy_3) + with pytest.raises(ValueError): + d.add_relation_type('Side Effect: Infertility', 1, 1, dummy_3) + with pytest.raises(ValueError): + d.add_relation_type('Side Effect: Death', 1, 1, dummy_3) + print(d) + + +def test_data_03(): + d = Data() + d.add_node_type('Gene', 1000) + d.add_node_type('Drug', 100) + with pytest.raises(ValueError): + d.add_relation_type('Target', 1, 0, None) + print(d)