|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
- from collections import defaultdict
- from ..weights import init_glorot
-
-
class NodeType:
    """A named category of graph nodes.

    Attributes:
        name: display name of the node type.
        count: number of nodes of this type, exactly as passed in.
    """

    def __init__(self, name, count):
        self.name, self.count = name, count
-
-
class RelationType(object):
    """A named relation between two node types, backed by an adjacency matrix.

    Rows of ``adjacency_matrix`` index nodes of ``node_type_row`` and columns
    index nodes of ``node_type_column``. A precomputed transpose may be
    supplied; otherwise the reverse orientation is produced on demand with
    ``adjacency_matrix.transpose(0, 1)``.
    """

    def __init__(self, name, node_type_row, node_type_column,
                 adjacency_matrix, adjacency_matrix_transposed=None):
        """Store the relation.

        Args:
            name: display name of the relation.
            node_type_row: node-type identifier for the matrix rows.
            node_type_column: node-type identifier for the matrix columns.
            adjacency_matrix: object exposing ``.shape`` and
                ``.transpose(0, 1)`` (assumed to be a torch tensor — TODO confirm).
            adjacency_matrix_transposed: optional precomputed transpose, or None.

        Raises:
            ValueError: if ``adjacency_matrix_transposed`` is given and its
                shape differs from the shape of the transposed
                ``adjacency_matrix``.
        """
        if adjacency_matrix_transposed is not None and \
                adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
            raise ValueError('adjacency_matrix_transposed has incorrect shape')

        self.name = name
        self.node_type_row = node_type_row
        self.node_type_column = node_type_column
        self.adjacency_matrix = adjacency_matrix
        self.adjacency_matrix_transposed = adjacency_matrix_transposed

    def get_adjacency_matrix(self, node_type_row, node_type_column):
        """Return the adjacency matrix oriented as (node_type_row, node_type_column).

        If the requested orientation matches the stored one, the stored matrix
        is returned. If it is the reverse, the precomputed transpose is
        returned when available; otherwise the transpose is computed.

        Raises:
            ValueError: if the pair does not match this relation in either
                orientation.
        """
        requested = (node_type_row, node_type_column)
        if requested == (self.node_type_row, self.node_type_column):
            return self.adjacency_matrix

        if requested == (self.node_type_column, self.node_type_row):
            # Compare with None explicitly: truth-testing a multi-element
            # tensor raises instead of checking for presence.
            if self.adjacency_matrix_transposed is not None:
                return self.adjacency_matrix_transposed
            return self.adjacency_matrix.transpose(0, 1)

        raise ValueError('Specified row/column types do not correspond to this relation')
-
-
class Data(object):
    """Heterogeneous graph description: node types plus typed relations.

    Node types are identified by their integer index in ``node_types``
    (the order in which they were added). Relations are stored in
    ``relation_types``, grouped under the ``(node_type_row, node_type_column)``
    key they were registered with.
    """

    def __init__(self):
        self.node_types = []
        # (row node-type index, column node-type index) -> [RelationType, ...]
        self.relation_types = defaultdict(list)

    def add_node_type(self, name, count):
        """Append a node type; its index is its position in ``node_types``."""
        self.node_types.append(NodeType(name, count))

    @staticmethod
    def _ensure_sparse(matrix):
        """Return *matrix* via ``to_sparse()`` unless it is None or already sparse."""
        if matrix is not None and not matrix.is_sparse:
            return matrix.to_sparse()
        return matrix

    def add_relation_type(self, name, node_type_row, node_type_column,
                          adjacency_matrix, adjacency_matrix_transposed=None):
        """Register a relation between two previously added node types.

        Args:
            name: display name of the relation.
            node_type_row: index of the node type indexing the matrix rows.
            node_type_column: index of the node type indexing the matrix columns.
            adjacency_matrix: matrix exposing ``is_sparse``/``to_sparse()``
                (assumed torch tensor — TODO confirm), or None. Converted to
                sparse if it is dense.
            adjacency_matrix_transposed: optional precomputed transpose,
                converted the same way so both orientations come back sparse.

        Raises:
            ValueError: if either index is outside ``range(len(node_types))``.
        """
        n = len(self.node_types)
        # Reject negative indices too; they would otherwise alias the last
        # node types and never be listed by __repr__.
        if not (0 <= node_type_row < n and 0 <= node_type_column < n):
            raise ValueError('Node type index out of bounds, add node type first')
        adjacency_matrix = self._ensure_sparse(adjacency_matrix)
        adjacency_matrix_transposed = self._ensure_sparse(adjacency_matrix_transposed)
        key = (node_type_row, node_type_column)
        self.relation_types[key].append(RelationType(
            name, node_type_row, node_type_column,
            adjacency_matrix, adjacency_matrix_transposed))

    def get_adjacency_matrices(self, node_type_row, node_type_column):
        """Return all adjacency matrices linking the two node types.

        Only relations registered as (row, column) or (column, row) are
        considered. Each is asked for the (node_type_row, node_type_column)
        orientation, so reversed relations come back transposed. Results
        follow relation registration order.
        """
        # A set dedupes the two keys when row == column (self-relations).
        wanted = {(node_type_row, node_type_column),
                  (node_type_column, node_type_row)}
        return [r.get_adjacency_matrix(node_type_row, node_type_column)
                for key, rels in self.relation_types.items() if key in wanted
                for r in rels]

    def __repr__(self):
        n_types = len(self.node_types)
        if n_types == 0:
            return 'Empty GNN Data'
        parts = ['GNN Data with:\n',
                 '- ' + str(n_types) + ' node type(s):\n']
        # NOTE(review): the ' - ' prefixes may have lost leading indentation
        # in the copy this was reviewed from — confirm intended nesting.
        for nt in self.node_types:
            parts.append(' - ' + nt.name + '\n')
        if len(self.relation_types) == 0:
            parts.append('- No relation types\n')
            return ''.join(parts).strip()
        # Count relations (list entries), not dict keys.
        n_rels = sum(len(rels) for rels in self.relation_types.values())
        parts.append('- ' + str(n_rels) + ' relation type(s):\n')
        # Keys are node-type index pairs, so iterate over node types.
        for i in range(n_types):
            for j in range(n_types):
                key = (i, j)
                if key not in self.relation_types:
                    continue
                parts.append(' - ' + self.node_types[i].name + ' -- ' +
                             self.node_types[j].name + ':\n')
                for r in self.relation_types[key]:
                    parts.append(' - ' + r.name + '\n')
        return ''.join(parts).strip()
|