from collections import defaultdict

from .decode import BilinearDecoder
from .weights import init_glorot


class Data(object):
    """Container for a multi-relational graph: node types, relations and decoders.

    Node types are identified by their integer index in ``node_types`` (the
    order in which they were added). Each relation links a row node type to a
    column node type and carries an adjacency matrix and a name. Every
    (row, column) node-type pair that has at least one relation is assigned a
    decoder class: ``BilinearDecoder`` unless overridden via
    ``set_decoder_type``.
    """

    def __init__(self):
        # Node type names in insertion order; the list index is the node type id.
        self.node_types = []
        # Tuples of (node_type_row, node_type_column, adjacency_matrix, name).
        self.relation_types = []
        # Maps (node_type_row, node_type_column) -> decoder class. Looking up a
        # missing key registers BilinearDecoder for it (defaultdict behaviour).
        self.decoder_types = defaultdict(lambda: BilinearDecoder)
        # One init_glorot(count, latent_length) result per node type, parallel
        # to node_types. NOTE(review): presumably a count x latent_length
        # weight matrix -- confirm against .weights.init_glorot.
        self.latent_node = []

    def add_node_type(self, name, count, latent_length):
        """Register a new node type and initialise its latent representation.

        Args:
            name: Display name of the node type. It is concatenated into
                ``__repr__`` output, so it should be a ``str``.
            count: Number of nodes of this type; forwarded to ``init_glorot``.
            latent_length: Latent dimension; forwarded to ``init_glorot``.
        """
        self.node_types.append(name)
        self.latent_node.append(init_glorot(count, latent_length))

    def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name):
        """Add a relation between two previously registered node types.

        Args:
            node_type_row: Index of the row node type.
            node_type_column: Index of the column node type.
            adjacency_matrix: Adjacency data for the relation (stored as-is).
            name: Relation name; shown in ``__repr__`` so it should be a ``str``.

        Raises:
            ValueError: If either index does not refer to an added node type.
        """
        n = len(self.node_types)
        # Negative indices are rejected too: they would slip past an upper-bound
        # check and then never match any node type in __repr__.
        if not (0 <= node_type_row < n and 0 <= node_type_column < n):
            raise ValueError('Node type index out of bounds, add node type first')
        self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name))
        # Touch the defaultdict so this pair is registered with the default
        # decoder; set_decoder_type requires the key to already be present.
        _ = self.decoder_types[(node_type_row, node_type_column)]

    def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
        """Override the decoder class for an existing (row, column) relation pair.

        Raises:
            ValueError: If no relation has been added for this pair.
        """
        # `in` does not trigger the defaultdict factory, so this is a real check.
        if (node_type_row, node_type_column) not in self.decoder_types:
            raise ValueError('Relation type not found, add relation first')
        self.decoder_types[(node_type_row, node_type_column)] = decoder_class

    def get_adjacency_matrices(self, node_type_row, node_type_column):
        """Return the adjacency matrices of all relations for a node-type pair.

        Args:
            node_type_row: Index of the row node type.
            node_type_column: Index of the column node type.

        Returns:
            A list of adjacency matrices, in the order the relations were added.

        Raises:
            ValueError: If no relation has been added for this pair.
        """
        matrices = [
            adjacency
            for row, col, adjacency, _name in self.relation_types
            if row == node_type_row and col == node_type_column
        ]
        if not matrices:
            raise ValueError('Relation type not found, add relation first')
        return matrices

    def __repr__(self):
        """Return a human-readable, multi-line summary of node and relation types."""
        n = len(self.node_types)
        if n == 0:
            return 'Empty GNN Data'
        parts = ['GNN Data with:', '- ' + str(n) + ' node type(s):']
        parts.extend(' - ' + nt for nt in self.node_types)
        if not self.relation_types:
            parts.append('- No relation types')
            return '\n'.join(parts).strip()
        parts.append('- ' + str(len(self.relation_types)) + ' relation type(s):')
        # Group relation names by (row, column) once instead of re-scanning all
        # relations for each of the n * n node-type pairs; insertion order kept.
        names_by_pair = defaultdict(list)
        for row, col, _adjacency, rel_name in self.relation_types:
            names_by_pair[(row, col)].append(rel_name)
        for i in range(n):
            for j in range(n):
                names = names_by_pair.get((i, j))
                if not names:
                    continue
                parts.append(' - ' + self.node_types[i] + ' -- ' + self.node_types[j]
                             + ' (' + self.decoder_types[(i, j)].__name__ + '):')
                parts.extend(' - ' + rel_name for rel_name in names)
        return '\n'.join(parts).strip()