IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Start implementing RelationFamily.

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
b5d2f8fcda
共有 2 个文件被更改,包括 133 次插入61 次删除
  1. +111
    -50
      src/icosagon/data.py
  2. +22
    -11
      tests/icosagon/test_data.py

+ 111
- 50
src/icosagon/data.py 查看文件

@@ -10,7 +10,10 @@ import torch
from typing import List, \
Dict, \
Tuple, \
Any
Any, \
Type
from .decode import DEDICOMDecoder, \
BilinearDecoder
@dataclass
@@ -25,25 +28,33 @@ class RelationType(object):
node_type_row: int
node_type_column: int
adjacency_matrix: torch.Tensor
two_way: bool
hints: Dict[str, Any] = field(default_factory=dict)
class Data(object):
node_types: List[NodeType]
relation_types: Dict[Tuple[int, int], List[RelationType]]
class RelationFamily(object):
def __init__(self,
data: 'Data',
name: str,
node_type_row: int,
node_type_column: int,
is_symmetric: bool,
decoder_class: Type) -> None:
def __init__(self) -> None:
self.node_types = []
self.relation_types = defaultdict(list)
if not is_symmetric and \
decoder_class != DEDICOMDecoder and \
decoder_class != BilinearDecoder:
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
def add_node_type(self, name: str, count: int) -> None:
name = str(name)
count = int(count)
if not name:
raise ValueError('You must provide a non-empty node type name')
if count <= 0:
raise ValueError('You must provide a positive node count')
self.node_types.append(NodeType(name, count))
self.data = data
self.name = name
self.node_type_row = node_type_row
self.node_type_column = node_type_column
self.is_symmetric = is_symmetric
self.decoder_class = decoder_class
self.relation_types = { (node_type_row, node_type_column): [],
(node_type_column, node_type_row): [] }
def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None,
@@ -53,45 +64,110 @@ class Data(object):
node_type_row = int(node_type_row)
node_type_column = int(node_type_column)
if node_type_row < 0 or node_type_row > len(self.node_types):
if (node_type_row, node_type_column) not in self.relation_types:
raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family')
if node_type_row < 0 or node_type_row >= len(self.data.node_types):
raise ValueError('node_type_row outside of the valid range of node types')
if node_type_column < 0 or node_type_column > len(self.node_types):
if node_type_column < 0 or node_type_column >= len(self.data.node_types):
raise ValueError('node_type_column outside of the valid range of node types')
if not isinstance(adjacency_matrix, torch.Tensor):
raise ValueError('adjacency_matrix must be a torch.Tensor')
if adjacency_matrix_backward and not isinstance(adjacency_matrix_backward, torch.Tensor):
if adjacency_matrix_backward is not None \
and not isinstance(adjacency_matrix_backward, torch.Tensor):
raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
if adjacency_matrix.shape != (self.node_types[node_type_row].count,
self.node_types[node_type_column].count):
if adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
self.data.node_types[node_type_column].count):
raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
if adjacency_matrix_backward is not None and \
adjacency_matrix_backward.shape != (self.node_types[node_type_column].count,
self.node_types[node_type_row].count):
adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
self.data.node_types[node_type_row].count):
raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)')
two_way = bool(two_way)
if node_type_row == node_type_column and \
adjacency_matrix_backward is not None:
raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
self.relation_types[node_type_row, node_type_column].append(
RelationType(name, node_type_row, node_type_column,
adjacency_matrix))
if self.is_symmetric and adjacency_matrix_backward is not None:
raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')
if self.is_symmetric and node_type_row == node_type_column and \
not torch.all(adjacency_matrix == adjacency_matrix.transpose(0, 1)):
raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')
two_way = bool(two_way)
if node_type_row != node_type_column and two_way:
hints = { 'display': False }
print('%d != %d' % (node_type_row, node_type_column))
if adjacency_matrix_backward is None:
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
hints['symmetric'] = True
self.relation_types[node_type_column, node_type_row].append(
RelationType(name, node_type_column, node_type_row,
adjacency_matrix_backward, hints))
adjacency_matrix_backward, two_way, { 'display': False }))
self.relation_types[node_type_row, node_type_column].append(
RelationType(name, node_type_row, node_type_column,
adjacency_matrix, two_way))
def node_name(self, index):
return self.data.node_types[index].name
def __repr__(self):
s = 'Relation family %s' % self.name
for (node_type_row, node_type_column), rels in self.relation_types.items():
for r in rels:
if 'display' in r.hints and not r.hints['display']:
continue
s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \
(self.node_name(node_type_row), self.node_name(node_type_column)))
return s
def repr_indented(self):
s = ' - %s' % self.name
for (node_type_row, node_type_column), rels in self.relation_types.items():
for r in rels:
if 'display' in r.hints and not r.hints['display']:
continue
s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \
(self.node_name(node_type_row), self.node_name(node_type_column)))
return s
class Data(object):
node_types: List[NodeType]
relation_types: Dict[Tuple[int, int], List[RelationType]]
def __init__(self) -> None:
self.node_types = []
self.relation_families = []
def add_node_type(self, name: str, count: int) -> None:
name = str(name)
count = int(count)
if not name:
raise ValueError('You must provide a non-empty node type name')
if count <= 0:
raise ValueError('You must provide a positive node count')
self.node_types.append(NodeType(name, count))
def add_relation_family(self, name: str, node_type_row: int,
node_type_column: int, is_symmetric: bool,
decoder_class: Type = DEDICOMDecoder):
fam = RelationFamily(self, name, node_type_row, node_type_column,
is_symmetric, decoder_class)
self.relation_families.append(fam)
return fam
def __repr__(self):
n = len(self.node_types)
@@ -102,27 +178,12 @@ class Data(object):
s += '- ' + str(n) + ' node type(s):\n'
for nt in self.node_types:
s += ' - ' + nt.name + '\n'
if len(self.relation_types) == 0:
s += '- No relation types\n'
if len(self.relation_families) == 0:
s += '- No relation families\n'
return s.strip()
s_1 = ''
count = 0
for node_type_row in range(n):
for node_type_column in range(n):
if (node_type_row, node_type_column) not in self.relation_types:
continue
s_1 += ' - ' + self.node_types[node_type_row].name + ' -- ' + \
self.node_types[node_type_column].name + ':\n'
for r in self.relation_types[node_type_row, node_type_column]:
if not r.hints.get('display', True):
continue
s_1 += ' - ' + r.name + '\n'
count += 1
s += '- %d relation type(s):\n' % count
s += s_1
s += '- %d relation families:\n' % len(self.relation_families)
for fam in self.relation_families:
s += fam.repr_indented() + '\n'
return s.strip()

+ 22
- 11
tests/icosagon/test_data.py 查看文件

@@ -17,11 +17,14 @@ def test_data_01():
dummy_1 = torch.zeros((1000, 100))
dummy_2 = torch.zeros((100, 100))
dummy_3 = torch.zeros((1000, 1000))
d.add_relation_type('Target', 1, 0, dummy_0)
d.add_relation_type('Interaction', 0, 0, dummy_3)
d.add_relation_type('Side Effect: Nausea', 1, 1, dummy_2)
d.add_relation_type('Side Effect: Infertility', 1, 1, dummy_2)
d.add_relation_type('Side Effect: Death', 1, 1, dummy_2)
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
fam.add_relation_type('Target', 1, 0, dummy_0)
fam = d.add_relation_family('Gene-Gene', 0, 0, True)
fam.add_relation_type('Interaction', 0, 0, dummy_3)
fam = d.add_relation_family('Drug-Drug', 1, 1, True)
fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_2)
fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_2)
fam.add_relation_type('Side Effect: Death', 1, 1, dummy_2)
print(d)
@@ -29,20 +32,27 @@ def test_data_02():
d = Data()
d.add_node_type('Gene', 1000)
d.add_node_type('Drug', 100)
dummy_0 = torch.zeros((100, 1000))
dummy_1 = torch.zeros((1000, 100))
dummy_2 = torch.zeros((100, 100))
dummy_3 = torch.zeros((1000, 1000))
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
with pytest.raises(ValueError):
d.add_relation_type('Target', 1, 0, dummy_1)
fam.add_relation_type('Target', 1, 0, dummy_1)
fam = d.add_relation_family('Gene-Gene', 0, 0, True)
with pytest.raises(ValueError):
d.add_relation_type('Interaction', 0, 0, dummy_2)
fam.add_relation_type('Interaction', 0, 0, dummy_2)
fam = d.add_relation_family('Drug-Drug', 1, 1, True)
with pytest.raises(ValueError):
d.add_relation_type('Side Effect: Nausea', 1, 1, dummy_3)
fam.add_relation_type('Side Effect: Nausea', 1, 1, dummy_3)
with pytest.raises(ValueError):
d.add_relation_type('Side Effect: Infertility', 1, 1, dummy_3)
fam.add_relation_type('Side Effect: Infertility', 1, 1, dummy_3)
with pytest.raises(ValueError):
d.add_relation_type('Side Effect: Death', 1, 1, dummy_3)
fam.add_relation_type('Side Effect: Death', 1, 1, dummy_3)
print(d)
@@ -50,6 +60,7 @@ def test_data_03():
d = Data()
d.add_node_type('Gene', 1000)
d.add_node_type('Drug', 100)
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
with pytest.raises(ValueError):
d.add_relation_type('Target', 1, 0, None)
fam.add_relation_type('Target', 1, 0, None)
print(d)

正在加载...
取消
保存