|
|
@@ -0,0 +1,127 @@ |
|
|
|
from collections import defaultdict
|
|
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
class NodeType(object):
|
|
|
|
name: str
|
|
|
|
count: int
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
class RelationType(object):
|
|
|
|
name: str
|
|
|
|
node_type_row: int
|
|
|
|
node_type_column: int
|
|
|
|
adjacency_matrix: torch.Tensor
|
|
|
|
is_autogenerated: bool
|
|
|
|
|
|
|
|
|
|
|
|
class Data(object):
|
|
|
|
def __init__(self) -> None:
|
|
|
|
self.node_types = []
|
|
|
|
self.relation_types = defaultdict(lambda: defaultdict(list))
|
|
|
|
|
|
|
|
def add_node_type(self, name: str, count: int) -> None:
|
|
|
|
name = str(name)
|
|
|
|
count = int(count)
|
|
|
|
if not name:
|
|
|
|
raise ValueError('You must provide a non-empty node type name')
|
|
|
|
if count <= 0:
|
|
|
|
raise ValueError('You must provide a positive node count')
|
|
|
|
self.node_types.append(NodeType(name, count))
|
|
|
|
|
|
|
|
def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
|
|
|
|
adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None,
|
|
|
|
two_way: bool = True) -> None:
|
|
|
|
|
|
|
|
name = str(name)
|
|
|
|
node_type_row = int(node_type_row)
|
|
|
|
node_type_column = int(node_type_column)
|
|
|
|
|
|
|
|
if node_type_row < 0 or node_type_row > len(self.node_types):
|
|
|
|
raise ValueError('node_type_row outside of the valid range of node types')
|
|
|
|
|
|
|
|
if node_type_column < 0 or node_type_column > len(self.node_types):
|
|
|
|
raise ValueError('node_type_column outside of the valid range of node types')
|
|
|
|
|
|
|
|
if not isinstance(adjacency_matrix, torch.Tensor):
|
|
|
|
raise ValueError('adjacency_matrix must be a torch.Tensor')
|
|
|
|
|
|
|
|
if adjacency_matrix_backward and not isinstance(adjacency_matrix_backward, torch.Tensor):
|
|
|
|
raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
|
|
|
|
|
|
|
|
if adjacency_matrix.shape != (self.node_types[node_type_row].count,
|
|
|
|
self.node_types[node_type_column].count):
|
|
|
|
raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
|
|
|
|
|
|
|
|
if adjacency_matrix_backward is not None and \
|
|
|
|
adjacency_matrix_backward.shape != (self.node_types[node_type_column].count,
|
|
|
|
self.node_types[node_type_row].count):
|
|
|
|
raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)')
|
|
|
|
|
|
|
|
two_way = bool(two_way)
|
|
|
|
|
|
|
|
if node_type_row == node_type_column and \
|
|
|
|
adjacency_matrix_backward is not None:
|
|
|
|
raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
|
|
|
|
|
|
|
|
self.relation_types[node_type_row][node_type_column].append(
|
|
|
|
RelationType(name, node_type_row, node_type_column,
|
|
|
|
adjacency_matrix, False))
|
|
|
|
|
|
|
|
if node_type_row != node_type_column and two_way:
|
|
|
|
if adjacency_matrix_backward is None:
|
|
|
|
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
|
|
|
|
self.relation_types[node_type_column][node_type_row].append(
|
|
|
|
RelationType(name, node_type_column, node_type_row,
|
|
|
|
adjacency_matrix_backward, True))
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
n = len(self.node_types)
|
|
|
|
if n == 0:
|
|
|
|
return 'Empty Icosagon Data'
|
|
|
|
s = ''
|
|
|
|
s += 'Icosagon Data with:\n'
|
|
|
|
s += '- ' + str(n) + ' node type(s):\n'
|
|
|
|
for nt in self.node_types:
|
|
|
|
s += ' - ' + nt.name + '\n'
|
|
|
|
if len(self.relation_types) == 0:
|
|
|
|
s += '- No relation types\n'
|
|
|
|
return s.strip()
|
|
|
|
|
|
|
|
s_1 = ''
|
|
|
|
count = 0
|
|
|
|
for i in range(n):
|
|
|
|
for j in range(n):
|
|
|
|
if i not in self.relation_types or \
|
|
|
|
j not in self.relation_types[i]:
|
|
|
|
continue
|
|
|
|
|
|
|
|
s_1 += ' - ' + self.node_types[i].name + ' -- ' + \
|
|
|
|
self.node_types[j].name + ':\n'
|
|
|
|
|
|
|
|
for r in self.relation_types[i][j]:
|
|
|
|
if r.is_autogenerated:
|
|
|
|
continue
|
|
|
|
s_1 += ' - ' + r.name + '\n'
|
|
|
|
count += 1
|
|
|
|
|
|
|
|
s += '- %d relation type(s):\n' % count
|
|
|
|
s += s_1
|
|
|
|
|
|
|
|
return s.strip()
|
|
|
|
|
|
|
|
# n = sum(map(len, self.relation_types))
|
|
|
|
#
|
|
|
|
# for i in range(n):
|
|
|
|
# for j in range(n):
|
|
|
|
# key = (i, j)
|
|
|
|
# if key not in self.relation_types:
|
|
|
|
# continue
|
|
|
|
# rels = self.relation_types[key]
|
|
|
|
#
|
|
|
|
# for r in rels:
|
|
|
|
#
|
|
|
|
# return s.strip()
|