IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

104 lignes
4.1KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from .weights import init_glorot
  7. class NodeType(object):
  8. def __init__(self, name, count):
  9. self.name = name
  10. self.count = count
  11. class RelationType(object):
  12. def __init__(self, name, node_type_row, node_type_column,
  13. adjacency_matrix):
  14. self.name = name
  15. self.node_type_row = node_type_row
  16. self.node_type_column = node_type_column
  17. self.adjacency_matrix = adjacency_matrix
  18. def get_adjacency_matrix(node_type_row, node_type_column):
  19. if self.node_type_row == node_type_row and \
  20. self.node_type_column == node_type_column:
  21. return self.adjacency_matrix
  22. elif self.node_type_row == node_type_column and \
  23. self.node_type_column == node_type_row:
  24. return self.adjacency_matrix.transpose(0, 1)
  25. else:
  26. raise ValueError('Specified row/column types do not correspond to this relation')
  27. class Data(object):
  28. def __init__(self):
  29. self.node_types = []
  30. self.relation_types = defaultdict(list)
  31. # self.decoder_types = defaultdict(lambda: BilinearDecoder)
  32. # self.latent_node = []
  33. def add_node_type(self, name, count): # , latent_length):
  34. self.node_types.append(NodeType(name, count))
  35. # self.latent_node.append(init_glorot(count, latent_length))
  36. def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix):
  37. n = len(self.node_types)
  38. if node_type_row >= n or node_type_column >= n:
  39. raise ValueError('Node type index out of bounds, add node type first')
  40. key = (node_type_row, node_type_column)
  41. if adjacency_matrix is not None and not adjacency_matrix.is_sparse:
  42. adjacency_matrix = adjacency_matrix.to_sparse()
  43. self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix))
  44. # _ = self.decoder_types[(node_type_row, node_type_column)]
  45. #def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
  46. # if (node_type_row, node_type_column) not in self.decoder_types:
  47. # raise ValueError('Relation type not found, add relation first')
  48. # self.decoder_types[(node_type_row, node_type_column)] = decoder_class
  49. def get_adjacency_matrices(self, node_type_row, node_type_column):
  50. # rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
  51. key = (node_type_row, node_type_column)
  52. if key not in self.relation_types:
  53. raise ValueError('Relation type not found')
  54. rels = self.relation_types[key]
  55. rels = list(map(lambda a: a.adjacency_matrix, rels))
  56. return rels
  57. def __repr__(self):
  58. n = len(self.node_types)
  59. if n == 0:
  60. return 'Empty GNN Data'
  61. s = ''
  62. s += 'GNN Data with:\n'
  63. s += '- ' + str(n) + ' node type(s):\n'
  64. for nt in self.node_types:
  65. s += ' - ' + nt.name + '\n'
  66. if len(self.relation_types) == 0:
  67. s += '- No relation types\n'
  68. return s.strip()
  69. n = sum(map(len, self.relation_types))
  70. s += '- ' + str(n) + ' relation type(s):\n'
  71. for i in range(n):
  72. for j in range(n):
  73. key = (i, j)
  74. if key not in self.relation_types:
  75. continue
  76. rels = self.relation_types[key]
  77. # rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
  78. #if len(rels) == 0:
  79. # continue
  80. # dir = '<->' if i == j else '->'
  81. dir = '--'
  82. s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n'
  83. #' (' + self.decoder_types[(i, j)].__name__ + '):\n'
  84. for r in rels:
  85. s += ' - ' + r.name + '\n'
  86. return s.strip()