| @@ -16,11 +16,16 @@ class NodeType(object): | |||||
| class RelationType(object): | class RelationType(object): | ||||
| def __init__(self, name, node_type_row, node_type_column, | def __init__(self, name, node_type_row, node_type_column, | ||||
| adjacency_matrix): | |||||
| adjacency_matrix, adjacency_matrix_transposed): | |||||
| if adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape: | |||||
| raise ValueError('adjacency_matrix_transposed has incorrect shape') | |||||
| self.name = name | self.name = name | ||||
| self.node_type_row = node_type_row | self.node_type_row = node_type_row | ||||
| self.node_type_column = node_type_column | self.node_type_column = node_type_column | ||||
| self.adjacency_matrix = adjacency_matrix | self.adjacency_matrix = adjacency_matrix | ||||
| self.adjacency_matrix_transposed = adjacency_matrix_transposed | |||||
| def get_adjacency_matrix(node_type_row, node_type_column): | def get_adjacency_matrix(node_type_row, node_type_column): | ||||
| if self.node_type_row == node_type_row and \ | if self.node_type_row == node_type_row and \ | ||||
| @@ -29,7 +34,10 @@ class RelationType(object): | |||||
| elif self.node_type_row == node_type_column and \ | elif self.node_type_row == node_type_column and \ | ||||
| self.node_type_column == node_type_row: | self.node_type_column == node_type_row: | ||||
| return self.adjacency_matrix.transpose(0, 1) | |||||
| if self.adjacency_matrix_transposed: | |||||
| return self.adjacency_matrix_transposed | |||||
| else: | |||||
| return self.adjacency_matrix.transpose(0, 1) | |||||
| else: | else: | ||||
| raise ValueError('Specified row/column types do not correspond to this relation') | raise ValueError('Specified row/column types do not correspond to this relation') | ||||
| @@ -39,37 +47,27 @@ class Data(object): | |||||
| def __init__(self): | def __init__(self): | ||||
| self.node_types = [] | self.node_types = [] | ||||
| self.relation_types = defaultdict(list) | self.relation_types = defaultdict(list) | ||||
| # self.decoder_types = defaultdict(lambda: BilinearDecoder) | |||||
| # self.latent_node = [] | |||||
| def add_node_type(self, name, count): # , latent_length): | def add_node_type(self, name, count): # , latent_length): | ||||
| self.node_types.append(NodeType(name, count)) | self.node_types.append(NodeType(name, count)) | ||||
| # self.latent_node.append(init_glorot(count, latent_length)) | |||||
| def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix): | |||||
| def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed=None): | |||||
| n = len(self.node_types) | n = len(self.node_types) | ||||
| if node_type_row >= n or node_type_column >= n: | if node_type_row >= n or node_type_column >= n: | ||||
| raise ValueError('Node type index out of bounds, add node type first') | raise ValueError('Node type index out of bounds, add node type first') | ||||
| key = (node_type_row, node_type_column) | key = (node_type_row, node_type_column) | ||||
| if adjacency_matrix is not None and not adjacency_matrix.is_sparse: | if adjacency_matrix is not None and not adjacency_matrix.is_sparse: | ||||
| adjacency_matrix = adjacency_matrix.to_sparse() | adjacency_matrix = adjacency_matrix.to_sparse() | ||||
| self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix)) | |||||
| # _ = self.decoder_types[(node_type_row, node_type_column)] | |||||
| #def set_decoder_type(self, node_type_row, node_type_column, decoder_class): | |||||
| # if (node_type_row, node_type_column) not in self.decoder_types: | |||||
| # raise ValueError('Relation type not found, add relation first') | |||||
| # self.decoder_types[(node_type_row, node_type_column)] = decoder_class | |||||
| self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed)) | |||||
| def get_adjacency_matrices(self, node_type_row, node_type_column): | def get_adjacency_matrices(self, node_type_row, node_type_column): | ||||
| # rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types) | |||||
| key = (node_type_row, node_type_column) | |||||
| if key not in self.relation_types: | |||||
| raise ValueError('Relation type not found') | |||||
| rels = self.relation_types[key] | |||||
| rels = list(map(lambda a: a.adjacency_matrix, rels)) | |||||
| return rels | |||||
| res = [] | |||||
| for (i, j), rels in self.relation_types.items(): | |||||
| if node_type_row not in [i, j] and node_type_column not in [i, j]: | |||||
| continue | |||||
| for r in rels: | |||||
| res.append(r.get_adjacency_matrix(node_type_row, node_type_column)) | |||||
| return res | |||||
| def __repr__(self): | def __repr__(self): | ||||
| n = len(self.node_types) | n = len(self.node_types) | ||||
| @@ -91,13 +89,7 @@ class Data(object): | |||||
| if key not in self.relation_types: | if key not in self.relation_types: | ||||
| continue | continue | ||||
| rels = self.relation_types[key] | rels = self.relation_types[key] | ||||
| # rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types)) | |||||
| #if len(rels) == 0: | |||||
| # continue | |||||
| # dir = '<->' if i == j else '->' | |||||
| dir = '--' | |||||
| s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n' | |||||
| #' (' + self.decoder_types[(i, j)].__name__ + '):\n' | |||||
| s += ' - ' + self.node_types[i].name + ' -- ' + self.node_types[j].name + ':\n' | |||||
| for r in rels: | for r in rels: | ||||
| s += ' - ' + r.name + '\n' | s += ' - ' + r.name + '\n' | ||||
| return s.strip() | return s.strip() | ||||