From fab210448d268d8833ace924a95c80b918924fbb Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Wed, 3 Jun 2020 16:45:03 +0200 Subject: [PATCH] Correct support for assymetric relations in Data. --- src/decagon_pytorch/data.py | 48 ++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/src/decagon_pytorch/data.py b/src/decagon_pytorch/data.py index a375f75..762b1a6 100644 --- a/src/decagon_pytorch/data.py +++ b/src/decagon_pytorch/data.py @@ -16,11 +16,16 @@ class NodeType(object): class RelationType(object): def __init__(self, name, node_type_row, node_type_column, - adjacency_matrix): + adjacency_matrix, adjacency_matrix_transposed): + + if adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape: + raise ValueError('adjacency_matrix_transposed has incorrect shape') + self.name = name self.node_type_row = node_type_row self.node_type_column = node_type_column self.adjacency_matrix = adjacency_matrix + self.adjacency_matrix_transposed = adjacency_matrix_transposed def get_adjacency_matrix(node_type_row, node_type_column): if self.node_type_row == node_type_row and \ @@ -29,7 +34,10 @@ class RelationType(object): elif self.node_type_row == node_type_column and \ self.node_type_column == node_type_row: - return self.adjacency_matrix.transpose(0, 1) + if self.adjacency_matrix_transposed: + return self.adjacency_matrix_transposed + else: + return self.adjacency_matrix.transpose(0, 1) else: raise ValueError('Specified row/column types do not correspond to this relation') @@ -39,37 +47,27 @@ class Data(object): def __init__(self): self.node_types = [] self.relation_types = defaultdict(list) - # self.decoder_types = defaultdict(lambda: BilinearDecoder) - # self.latent_node = [] def add_node_type(self, name, count): # , latent_length): self.node_types.append(NodeType(name, count)) - # self.latent_node.append(init_glorot(count, latent_length)) - def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix): + def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed=None): n = len(self.node_types) if node_type_row >= n or node_type_column >= n: raise ValueError('Node type index out of bounds, add node type first') key = (node_type_row, node_type_column) if adjacency_matrix is not None and not adjacency_matrix.is_sparse: adjacency_matrix = adjacency_matrix.to_sparse() - self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix)) - # _ = self.decoder_types[(node_type_row, node_type_column)] - - #def set_decoder_type(self, node_type_row, node_type_column, decoder_class): - # if (node_type_row, node_type_column) not in self.decoder_types: - # raise ValueError('Relation type not found, add relation first') - # self.decoder_types[(node_type_row, node_type_column)] = decoder_class + self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed)) def get_adjacency_matrices(self, node_type_row, node_type_column): - # rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types) - key = (node_type_row, node_type_column) - if key not in self.relation_types: - raise ValueError('Relation type not found') - rels = self.relation_types[key] - rels = list(map(lambda a: a.adjacency_matrix, rels)) - return rels - + res = [] + for (i, j), rels in self.relation_types.items(): + if node_type_row not in [i, j] and node_type_column not in [i, j]: + continue + for r in rels: + res.append(r.get_adjacency_matrix(node_type_row, node_type_column)) + return res def __repr__(self): n = len(self.node_types) @@ -91,13 +89,7 @@ class Data(object): if key not in self.relation_types: continue rels = self.relation_types[key] - # rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types)) - #if len(rels) == 0: - # continue - # dir = '<->' if i == j else '->' - dir = '--' - s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n' - #' (' + self.decoder_types[(i, j)].__name__ + '):\n' + s += ' - ' + self.node_types[i].name + ' -- ' + self.node_types[j].name + ':\n' for r in rels: s += ' - ' + r.name + '\n' return s.strip()