from collections import defaultdict
from dataclasses import dataclass
from typing import DefaultDict, List, Optional

import torch


@dataclass
class NodeType(object):
    """A named group of ``count`` nodes."""
    name: str
    count: int


@dataclass
class RelationType(object):
    """A single directed relation between two node types.

    ``node_type_row`` and ``node_type_column`` are indices into
    ``Data.node_types``. ``adjacency_matrix`` has shape
    ``(count of row node type, count of column node type)``, as enforced by
    ``Data.add_relation_type``. ``is_autogenerated`` is True for the reverse
    relation that ``add_relation_type`` creates for two-way relations.
    """
    name: str
    node_type_row: int
    node_type_column: int
    adjacency_matrix: torch.Tensor
    is_autogenerated: bool


class Data(object):
    """Container for node types and the relation types between them."""

    def __init__(self) -> None:
        self.node_types: List[NodeType] = []
        # relation_types[row_type_index][column_type_index] -> list of RelationType
        self.relation_types: DefaultDict[int, DefaultDict[int, List[RelationType]]] = \
            defaultdict(lambda: defaultdict(list))

    def add_node_type(self, name: str, count: int) -> None:
        """Register a new node type; its index is its position in ``node_types``.

        Raises:
            ValueError: if ``name`` is empty or ``count`` is not positive.
        """
        name = str(name)
        count = int(count)
        if not name:
            raise ValueError('You must provide a non-empty node type name')
        if count <= 0:
            raise ValueError('You must provide a positive node count')
        self.node_types.append(NodeType(name, count))

    def _check_node_type_index(self, index: int, arg_name: str) -> None:
        """Raise ValueError unless ``index`` is a valid index into ``node_types``."""
        # ``>=``: an index equal to len(node_types) is already out of range.
        if index < 0 or index >= len(self.node_types):
            raise ValueError(arg_name + ' outside of the valid range of node types')

    def add_relation_type(self, name: str, node_type_row: int,
                          node_type_column: int,
                          adjacency_matrix: torch.Tensor,
                          adjacency_matrix_backward: Optional[torch.Tensor] = None,
                          two_way: bool = True) -> None:
        """Register a relation from node type ``node_type_row`` to ``node_type_column``.

        Args:
            name: relation name.
            node_type_row: index of the source node type.
            node_type_column: index of the target node type.
            adjacency_matrix: tensor of shape (num_row_nodes, num_column_nodes).
            adjacency_matrix_backward: optional tensor of shape
                (num_column_nodes, num_row_nodes) used for the reverse
                relation; defaults to the transpose of ``adjacency_matrix``.
                Not allowed when both node types are the same.
            two_way: when True and the node types differ, also register the
                reverse relation, marked ``is_autogenerated=True``.

        Raises:
            ValueError: on an out-of-range node type index, a non-tensor
                matrix, a shape mismatch, or a backward matrix supplied for a
                relation between nodes of the same type.
        """
        name = str(name)
        node_type_row = int(node_type_row)
        node_type_column = int(node_type_column)

        self._check_node_type_index(node_type_row, 'node_type_row')
        self._check_node_type_index(node_type_column, 'node_type_column')

        if not isinstance(adjacency_matrix, torch.Tensor):
            raise ValueError('adjacency_matrix must be a torch.Tensor')

        # Compare with None explicitly: bool() of a multi-element tensor raises
        # RuntimeError, and falsy non-tensors must still be rejected here.
        if adjacency_matrix_backward is not None and \
                not isinstance(adjacency_matrix_backward, torch.Tensor):
            raise ValueError('adjacency_matrix_backward must be a torch.Tensor')

        num_row_nodes = self.node_types[node_type_row].count
        num_column_nodes = self.node_types[node_type_column].count

        if adjacency_matrix.shape != (num_row_nodes, num_column_nodes):
            raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')

        if adjacency_matrix_backward is not None and \
                adjacency_matrix_backward.shape != (num_column_nodes, num_row_nodes):
            raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')

        two_way = bool(two_way)

        if node_type_row == node_type_column and \
                adjacency_matrix_backward is not None:
            raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')

        self.relation_types[node_type_row][node_type_column].append(
            RelationType(name, node_type_row, node_type_column,
                         adjacency_matrix, False))

        # Same-type relations are stored once; only cross-type two-way
        # relations get an auto-generated reverse entry.
        if node_type_row != node_type_column and two_way:
            if adjacency_matrix_backward is None:
                adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
            self.relation_types[node_type_column][node_type_row].append(
                RelationType(name, node_type_column, node_type_row,
                             adjacency_matrix_backward, True))

    def __repr__(self) -> str:
        """Summarize node types and user-added (non-autogenerated) relations."""
        n = len(self.node_types)

        if n == 0:
            return 'Empty Icosagon Data'

        s = 'Icosagon Data with:\n'
        s += '- ' + str(n) + ' node type(s):\n'
        s += ''.join(' - ' + nt.name + '\n' for nt in self.node_types)

        if len(self.relation_types) == 0:
            s += '- No relation types\n'
            return s.strip()

        lines = []
        count = 0
        for i in range(n):
            # Membership tests (not indexing) so that __repr__ never inserts
            # empty entries into the defaultdicts.
            if i not in self.relation_types:
                continue
            row = self.relation_types[i]
            for j in range(n):
                if j not in row:
                    continue
                lines.append(' - ' + self.node_types[i].name + ' -- ' +
                             self.node_types[j].name + ':\n')
                for r in row[j]:
                    if r.is_autogenerated:
                        continue
                    lines.append(' - ' + r.name + '\n')
                    count += 1

        s += '- %d relation type(s):\n' % count
        s += ''.join(lines)
        return s.strip()