IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Correct support for assymetric relations in Data.

master
Stanislaw Adaszewski 3 years ago
parent
commit
fab210448d
1 changed files with 20 additions and 28 deletions
  1. +20
    -28
      src/decagon_pytorch/data.py

+ 20
- 28
src/decagon_pytorch/data.py View File

@@ -16,11 +16,16 @@ class NodeType(object):
class RelationType(object):
def __init__(self, name, node_type_row, node_type_column,
adjacency_matrix):
adjacency_matrix, adjacency_matrix_transposed):
if adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
raise ValueError('adjacency_matrix_transposed has incorrect shape')
self.name = name
self.node_type_row = node_type_row
self.node_type_column = node_type_column
self.adjacency_matrix = adjacency_matrix
self.adjacency_matrix_transposed = adjacency_matrix_transposed
def get_adjacency_matrix(node_type_row, node_type_column):
if self.node_type_row == node_type_row and \
@@ -29,7 +34,10 @@ class RelationType(object):
elif self.node_type_row == node_type_column and \
self.node_type_column == node_type_row:
return self.adjacency_matrix.transpose(0, 1)
if self.adjacency_matrix_transposed:
return self.adjacency_matrix_transposed
else:
return self.adjacency_matrix.transpose(0, 1)
else:
raise ValueError('Specified row/column types do not correspond to this relation')
@@ -39,37 +47,27 @@ class Data(object):
def __init__(self):
self.node_types = []
self.relation_types = defaultdict(list)
# self.decoder_types = defaultdict(lambda: BilinearDecoder)
# self.latent_node = []
def add_node_type(self, name, count): # , latent_length):
self.node_types.append(NodeType(name, count))
# self.latent_node.append(init_glorot(count, latent_length))
def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix):
def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed=None):
n = len(self.node_types)
if node_type_row >= n or node_type_column >= n:
raise ValueError('Node type index out of bounds, add node type first')
key = (node_type_row, node_type_column)
if adjacency_matrix is not None and not adjacency_matrix.is_sparse:
adjacency_matrix = adjacency_matrix.to_sparse()
self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix))
# _ = self.decoder_types[(node_type_row, node_type_column)]
#def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
# if (node_type_row, node_type_column) not in self.decoder_types:
# raise ValueError('Relation type not found, add relation first')
# self.decoder_types[(node_type_row, node_type_column)] = decoder_class
self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed))
def get_adjacency_matrices(self, node_type_row, node_type_column):
# rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
key = (node_type_row, node_type_column)
if key not in self.relation_types:
raise ValueError('Relation type not found')
rels = self.relation_types[key]
rels = list(map(lambda a: a.adjacency_matrix, rels))
return rels
res = []
for (i, j), rels in self.relation_types.items():
if node_type_row not in [i, j] and node_type_column not in [i, j]:
continue
for r in rels:
res.append(r.get_adjacency_matrix(node_type_row, node_type_column))
return res
def __repr__(self):
n = len(self.node_types)
@@ -91,13 +89,7 @@ class Data(object):
if key not in self.relation_types:
continue
rels = self.relation_types[key]
# rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
#if len(rels) == 0:
# continue
# dir = '<->' if i == j else '->'
dir = '--'
s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n'
#' (' + self.decoder_types[(i, j)].__name__ + '):\n'
s += ' - ' + self.node_types[i].name + ' -- ' + self.node_types[j].name + ':\n'
for r in rels:
s += ' - ' + r.name + '\n'
return s.strip()

Loading…
Cancel
Save