IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

97 lignes
3.6KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from ..weights import init_glorot
  7. class NodeType(object):
  8. def __init__(self, name, count):
  9. self.name = name
  10. self.count = count
  11. class RelationType(object):
  12. def __init__(self, name, node_type_row, node_type_column,
  13. adjacency_matrix, adjacency_matrix_transposed):
  14. if adjacency_matrix_transposed is not None and \
  15. adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
  16. raise ValueError('adjacency_matrix_transposed has incorrect shape')
  17. self.name = name
  18. self.node_type_row = node_type_row
  19. self.node_type_column = node_type_column
  20. self.adjacency_matrix = adjacency_matrix
  21. self.adjacency_matrix_transposed = adjacency_matrix_transposed
  22. def get_adjacency_matrix(node_type_row, node_type_column):
  23. if self.node_type_row == node_type_row and \
  24. self.node_type_column == node_type_column:
  25. return self.adjacency_matrix
  26. elif self.node_type_row == node_type_column and \
  27. self.node_type_column == node_type_row:
  28. if self.adjacency_matrix_transposed:
  29. return self.adjacency_matrix_transposed
  30. else:
  31. return self.adjacency_matrix.transpose(0, 1)
  32. else:
  33. raise ValueError('Specified row/column types do not correspond to this relation')
  34. class Data(object):
  35. def __init__(self):
  36. self.node_types = []
  37. self.relation_types = defaultdict(list)
  38. def add_node_type(self, name, count): # , latent_length):
  39. self.node_types.append(NodeType(name, count))
  40. def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed=None):
  41. n = len(self.node_types)
  42. if node_type_row >= n or node_type_column >= n:
  43. raise ValueError('Node type index out of bounds, add node type first')
  44. key = (node_type_row, node_type_column)
  45. if adjacency_matrix is not None and not adjacency_matrix.is_sparse:
  46. adjacency_matrix = adjacency_matrix.to_sparse()
  47. self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed))
  48. def get_adjacency_matrices(self, node_type_row, node_type_column):
  49. res = []
  50. for (i, j), rels in self.relation_types.items():
  51. if node_type_row not in [i, j] and node_type_column not in [i, j]:
  52. continue
  53. for r in rels:
  54. res.append(r.get_adjacency_matrix(node_type_row, node_type_column))
  55. return res
  56. def __repr__(self):
  57. n = len(self.node_types)
  58. if n == 0:
  59. return 'Empty GNN Data'
  60. s = ''
  61. s += 'GNN Data with:\n'
  62. s += '- ' + str(n) + ' node type(s):\n'
  63. for nt in self.node_types:
  64. s += ' - ' + nt.name + '\n'
  65. if len(self.relation_types) == 0:
  66. s += '- No relation types\n'
  67. return s.strip()
  68. n = sum(map(len, self.relation_types))
  69. s += '- ' + str(n) + ' relation type(s):\n'
  70. for i in range(n):
  71. for j in range(n):
  72. key = (i, j)
  73. if key not in self.relation_types:
  74. continue
  75. rels = self.relation_types[key]
  76. s += ' - ' + self.node_types[i].name + ' -- ' + self.node_types[j].name + ':\n'
  77. for r in rels:
  78. s += ' - ' + r.name + '\n'
  79. return s.strip()