|
|
@@ -3,33 +3,53 @@ from .decode import BilinearDecoder |
|
|
|
from .weights import init_glorot
|
|
|
|
|
|
|
|
|
|
|
|
class NodeType(object):
|
|
|
|
def __init__(self, name, count):
|
|
|
|
self.name = name
|
|
|
|
self.count = count
|
|
|
|
|
|
|
|
|
|
|
|
class RelationType(object):
|
|
|
|
def __init__(self, name, node_type_row, node_type_column,
|
|
|
|
adjacency_matrix):
|
|
|
|
self.name = name
|
|
|
|
self.node_type_row = node_type_row
|
|
|
|
self.node_type_column = node_type_column
|
|
|
|
self.adjacency_matrix = adjacency_matrix
|
|
|
|
|
|
|
|
|
|
|
|
class Data(object):
|
|
|
|
def __init__(self):
|
|
|
|
self.node_types = []
|
|
|
|
self.relation_types = []
|
|
|
|
self.decoder_types = defaultdict(lambda: BilinearDecoder)
|
|
|
|
self.latent_node = []
|
|
|
|
self.relation_types = defaultdict(list)
|
|
|
|
# self.decoder_types = defaultdict(lambda: BilinearDecoder)
|
|
|
|
# self.latent_node = []
|
|
|
|
|
|
|
|
def add_node_type(self, name, count, latent_length):
|
|
|
|
self.node_types.append(name)
|
|
|
|
self.latent_node.append(init_glorot(count, latent_length))
|
|
|
|
def add_node_type(self, name, count): # , latent_length):
|
|
|
|
self.node_types.append(NodeType(name, count))
|
|
|
|
# self.latent_node.append(init_glorot(count, latent_length))
|
|
|
|
|
|
|
|
def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name):
|
|
|
|
def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix):
|
|
|
|
n = len(self.node_types)
|
|
|
|
if node_type_row >= n or node_type_column >= n:
|
|
|
|
raise ValueError('Node type index out of bounds, add node type first')
|
|
|
|
self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name))
|
|
|
|
_ = self.decoder_types[(node_type_row, node_type_column)]
|
|
|
|
key = (node_type_row, node_type_column)
|
|
|
|
self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix))
|
|
|
|
# _ = self.decoder_types[(node_type_row, node_type_column)]
|
|
|
|
|
|
|
|
def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
|
|
|
|
if (node_type_row, node_type_column) not in self.decoder_types:
|
|
|
|
raise ValueError('Relation type not found, add relation first')
|
|
|
|
self.decoder_types[(node_type_row, node_type_column)] = decoder_class
|
|
|
|
#def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
|
|
|
|
# if (node_type_row, node_type_column) not in self.decoder_types:
|
|
|
|
# raise ValueError('Relation type not found, add relation first')
|
|
|
|
# self.decoder_types[(node_type_row, node_type_column)] = decoder_class
|
|
|
|
|
|
|
|
def get_adjacency_matrices(self, node_type_row, node_type_column):
|
|
|
|
rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
|
|
|
|
if len(rels) == 0:
|
|
|
|
|
|
|
|
# rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
|
|
|
|
key = (node_type_row, node_type_column)
|
|
|
|
if key not in self.relation_types:
|
|
|
|
raise ValueError('Relation type not found')
|
|
|
|
rels = self.relation_types[key]
|
|
|
|
rels = list(map(lambda a: a.adjacency_matrix, rels))
|
|
|
|
return rels
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
@@ -40,19 +60,25 @@ class Data(object): |
|
|
|
s += 'GNN Data with:\n'
|
|
|
|
s += '- ' + str(n) + ' node type(s):\n'
|
|
|
|
for nt in self.node_types:
|
|
|
|
s += ' - ' + nt + '\n'
|
|
|
|
s += ' - ' + nt.name + '\n'
|
|
|
|
if len(self.relation_types) == 0:
|
|
|
|
s += '- No relation types\n'
|
|
|
|
return s.strip()
|
|
|
|
s += '- ' + str(len(self.relation_types)) + ' relation type(s):\n'
|
|
|
|
n = sum(map(len, self.relation_types))
|
|
|
|
s += '- ' + str(n) + ' relation type(s):\n'
|
|
|
|
for i in range(n):
|
|
|
|
for j in range(n):
|
|
|
|
rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
|
|
|
|
if len(rels) == 0:
|
|
|
|
key = (i, j)
|
|
|
|
if key not in self.relation_types:
|
|
|
|
continue
|
|
|
|
rels = self.relation_types[key]
|
|
|
|
# rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
|
|
|
|
#if len(rels) == 0:
|
|
|
|
# continue
|
|
|
|
# dir = '<->' if i == j else '->'
|
|
|
|
dir = '--'
|
|
|
|
s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ' (' + self.decoder_types[(i, j)].__name__ + '):\n'
|
|
|
|
s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n'
|
|
|
|
#' (' + self.decoder_types[(i, j)].__name__ + '):\n'
|
|
|
|
for r in rels:
|
|
|
|
s += ' - ' + r[3] + '\n'
|
|
|
|
s += ' - ' + r.name + '\n'
|
|
|
|
return s.strip()
|