IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

96 рядків
3.6KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from .weights import init_glorot
  7. class NodeType(object):
  8. def __init__(self, name, count):
  9. self.name = name
  10. self.count = count
  11. class RelationType(object):
  12. def __init__(self, name, node_type_row, node_type_column,
  13. adjacency_matrix, adjacency_matrix_transposed):
  14. if adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
  15. raise ValueError('adjacency_matrix_transposed has incorrect shape')
  16. self.name = name
  17. self.node_type_row = node_type_row
  18. self.node_type_column = node_type_column
  19. self.adjacency_matrix = adjacency_matrix
  20. self.adjacency_matrix_transposed = adjacency_matrix_transposed
  21. def get_adjacency_matrix(node_type_row, node_type_column):
  22. if self.node_type_row == node_type_row and \
  23. self.node_type_column == node_type_column:
  24. return self.adjacency_matrix
  25. elif self.node_type_row == node_type_column and \
  26. self.node_type_column == node_type_row:
  27. if self.adjacency_matrix_transposed:
  28. return self.adjacency_matrix_transposed
  29. else:
  30. return self.adjacency_matrix.transpose(0, 1)
  31. else:
  32. raise ValueError('Specified row/column types do not correspond to this relation')
  33. class Data(object):
  34. def __init__(self):
  35. self.node_types = []
  36. self.relation_types = defaultdict(list)
  37. def add_node_type(self, name, count): # , latent_length):
  38. self.node_types.append(NodeType(name, count))
  39. def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed=None):
  40. n = len(self.node_types)
  41. if node_type_row >= n or node_type_column >= n:
  42. raise ValueError('Node type index out of bounds, add node type first')
  43. key = (node_type_row, node_type_column)
  44. if adjacency_matrix is not None and not adjacency_matrix.is_sparse:
  45. adjacency_matrix = adjacency_matrix.to_sparse()
  46. self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed))
  47. def get_adjacency_matrices(self, node_type_row, node_type_column):
  48. res = []
  49. for (i, j), rels in self.relation_types.items():
  50. if node_type_row not in [i, j] and node_type_column not in [i, j]:
  51. continue
  52. for r in rels:
  53. res.append(r.get_adjacency_matrix(node_type_row, node_type_column))
  54. return res
  55. def __repr__(self):
  56. n = len(self.node_types)
  57. if n == 0:
  58. return 'Empty GNN Data'
  59. s = ''
  60. s += 'GNN Data with:\n'
  61. s += '- ' + str(n) + ' node type(s):\n'
  62. for nt in self.node_types:
  63. s += ' - ' + nt.name + '\n'
  64. if len(self.relation_types) == 0:
  65. s += '- No relation types\n'
  66. return s.strip()
  67. n = sum(map(len, self.relation_types))
  68. s += '- ' + str(n) + ' relation type(s):\n'
  69. for i in range(n):
  70. for j in range(n):
  71. key = (i, j)
  72. if key not in self.relation_types:
  73. continue
  74. rels = self.relation_types[key]
  75. s += ' - ' + self.node_types[i].name + ' -- ' + self.node_types[j].name + ':\n'
  76. for r in rels:
  77. s += ' - ' + r.name + '\n'
  78. return s.strip()