|
|
@@ -0,0 +1,38 @@ |
|
|
|
class Data(object):
|
|
|
|
def __init__(self):
|
|
|
|
self.node_types = []
|
|
|
|
self.relation_types = []
|
|
|
|
|
|
|
|
def add_node_type(self, name):
|
|
|
|
self.node_types.append(name)
|
|
|
|
|
|
|
|
def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name):
|
|
|
|
n = len(self.node_types)
|
|
|
|
if node_type_row >= n or node_type_column >= n:
|
|
|
|
raise ValueError('Node type index out of bounds, add node type first')
|
|
|
|
self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name))
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
n = len(self.node_types)
|
|
|
|
if n == 0:
|
|
|
|
return 'Empty GNN Data'
|
|
|
|
s = ''
|
|
|
|
s += 'GNN Data with:\n'
|
|
|
|
s += '- ' + str(n) + ' node type(s):\n'
|
|
|
|
for nt in self.node_types:
|
|
|
|
s += ' - ' + nt + '\n'
|
|
|
|
if len(self.relation_types) == 0:
|
|
|
|
s += '- No relation types\n'
|
|
|
|
return s.strip()
|
|
|
|
s += '- ' + str(len(self.relation_types)) + ' relation type(s):\n'
|
|
|
|
for i in range(n):
|
|
|
|
for j in range(n):
|
|
|
|
rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
|
|
|
|
if len(rels) == 0:
|
|
|
|
continue
|
|
|
|
# dir = '<->' if i == j else '->'
|
|
|
|
dir = '--'
|
|
|
|
s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ':\n'
|
|
|
|
for r in rels:
|
|
|
|
s += ' - ' + r[3] + '\n'
|
|
|
|
return s.strip()
|