|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- from collections import defaultdict
- from .decode import BilinearDecoder
- from .weights import init_glorot
-
-
class Data(object):
    """Description of a multi-type graph: node types, relations and decoders.

    Node types are referred to by their integer position in ``node_types``.
    Every ``(row, column)`` pair of node types that has at least one relation
    gets a decoder class, ``BilinearDecoder`` unless overridden with
    ``set_decoder_type``.
    """

    def __init__(self):
        # Node type names; a node type's position in this list is its index.
        self.node_types = []
        # One tuple per relation:
        # (node_type_row, node_type_column, adjacency_matrix, name).
        self.relation_types = []
        # Decoder class per (node_type_row, node_type_column) pair. Reading a
        # missing key inserts and returns BilinearDecoder.
        self.decoder_types = defaultdict(lambda: BilinearDecoder)
        # Result of init_glorot(count, latent_length) for each node type,
        # parallel to node_types.
        self.latent_node = []

    def add_node_type(self, name, count, latent_length):
        """Register a node type and initialise its latent matrix.

        Args:
            name: Display name of the node type (used by ``__repr__``).
            count: First argument passed to ``init_glorot`` -- presumably the
                number of nodes of this type (TODO confirm).
            latent_length: Second argument passed to ``init_glorot`` --
                presumably the latent dimension (TODO confirm).
        """
        self.node_types.append(name)
        self.latent_node.append(init_glorot(count, latent_length))

    def _check_node_type_index(self, index):
        """Raise ValueError unless ``index`` is in ``range(len(self.node_types))``."""
        # Negative indices are rejected too: otherwise the relation would be
        # stored under a pair that __repr__ (iterating 0..n-1) never displays.
        if not 0 <= index < len(self.node_types):
            raise ValueError('Node type index out of bounds, add node type first')

    def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name):
        """Add a named relation (adjacency matrix) between two node types.

        Args:
            node_type_row: Index of the row node type.
            node_type_column: Index of the column node type.
            adjacency_matrix: Stored as-is in ``relation_types``.
            name: Display name of the relation (used by ``__repr__``).

        Raises:
            ValueError: if either index does not refer to an added node type.
        """
        self._check_node_type_index(node_type_row)
        self._check_node_type_index(node_type_column)
        self.relation_types.append(
            (node_type_row, node_type_column, adjacency_matrix, name))
        # Reading the defaultdict registers the pair with the default decoder;
        # set_decoder_type only accepts pairs that are already present.
        _ = self.decoder_types[(node_type_row, node_type_column)]

    def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
        """Override the decoder class used for a (row, column) node type pair.

        Raises:
            ValueError: if no relation between these node types was added.
        """
        key = (node_type_row, node_type_column)
        # A membership test does not trigger the defaultdict factory.
        if key not in self.decoder_types:
            raise ValueError('Relation type not found, add relation first')
        self.decoder_types[key] = decoder_class

    def get_adjacency_matrices(self, node_type_row, node_type_column):
        """Return the adjacency matrices of all relations between two node types.

        Returns:
            list: the matrices, in the order their relations were added.

        Raises:
            ValueError: if no relation between these node types was added.
        """
        matrices = [adj for row, col, adj, _name in self.relation_types
                    if row == node_type_row and col == node_type_column]
        if not matrices:
            raise ValueError('Relation type not found, add relation first')
        return matrices

    def __repr__(self):
        """Return a multi-line summary of node types, relations and decoders."""
        n = len(self.node_types)
        if n == 0:
            return 'Empty GNN Data'
        lines = ['GNN Data with:', '- ' + str(n) + ' node type(s):']
        lines.extend(' - ' + nt for nt in self.node_types)
        if not self.relation_types:
            lines.append('- No relation types')
            return '\n'.join(lines).strip()
        lines.append('- ' + str(len(self.relation_types)) + ' relation type(s):')
        # Group relations by (row, column) once instead of re-scanning the
        # whole list for each of the n * n pairs.
        grouped = defaultdict(list)
        for rel in self.relation_types:
            grouped[(rel[0], rel[1])].append(rel)
        for i in range(n):
            for j in range(n):
                rels = grouped.get((i, j))
                if not rels:
                    continue
                # The separator is always '--', whether or not i == j.
                lines.append(' - ' + self.node_types[i] + ' -- ' + self.node_types[j]
                             + ' (' + self.decoder_types[(i, j)].__name__ + '):')
                lines.extend(' - ' + rel[3] for rel in rels)
        return '\n'.join(lines).strip()
|