IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un chiffre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

85 lignes
3.4KB

  1. from collections import defaultdict
  2. from .decode import BilinearDecoder
  3. from .weights import init_glorot
  4. class NodeType(object):
  5. def __init__(self, name, count):
  6. self.name = name
  7. self.count = count
  8. class RelationType(object):
  9. def __init__(self, name, node_type_row, node_type_column,
  10. adjacency_matrix):
  11. self.name = name
  12. self.node_type_row = node_type_row
  13. self.node_type_column = node_type_column
  14. self.adjacency_matrix = adjacency_matrix
  15. class Data(object):
  16. def __init__(self):
  17. self.node_types = []
  18. self.relation_types = defaultdict(list)
  19. # self.decoder_types = defaultdict(lambda: BilinearDecoder)
  20. # self.latent_node = []
  21. def add_node_type(self, name, count): # , latent_length):
  22. self.node_types.append(NodeType(name, count))
  23. # self.latent_node.append(init_glorot(count, latent_length))
  24. def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix):
  25. n = len(self.node_types)
  26. if node_type_row >= n or node_type_column >= n:
  27. raise ValueError('Node type index out of bounds, add node type first')
  28. key = (node_type_row, node_type_column)
  29. self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix))
  30. # _ = self.decoder_types[(node_type_row, node_type_column)]
  31. #def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
  32. # if (node_type_row, node_type_column) not in self.decoder_types:
  33. # raise ValueError('Relation type not found, add relation first')
  34. # self.decoder_types[(node_type_row, node_type_column)] = decoder_class
  35. def get_adjacency_matrices(self, node_type_row, node_type_column):
  36. # rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
  37. key = (node_type_row, node_type_column)
  38. if key not in self.relation_types:
  39. raise ValueError('Relation type not found')
  40. rels = self.relation_types[key]
  41. rels = list(map(lambda a: a.adjacency_matrix, rels))
  42. return rels
  43. def __repr__(self):
  44. n = len(self.node_types)
  45. if n == 0:
  46. return 'Empty GNN Data'
  47. s = ''
  48. s += 'GNN Data with:\n'
  49. s += '- ' + str(n) + ' node type(s):\n'
  50. for nt in self.node_types:
  51. s += ' - ' + nt.name + '\n'
  52. if len(self.relation_types) == 0:
  53. s += '- No relation types\n'
  54. return s.strip()
  55. n = sum(map(len, self.relation_types))
  56. s += '- ' + str(n) + ' relation type(s):\n'
  57. for i in range(n):
  58. for j in range(n):
  59. key = (i, j)
  60. if key not in self.relation_types:
  61. continue
  62. rels = self.relation_types[key]
  63. # rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
  64. #if len(rels) == 0:
  65. # continue
  66. # dir = '<->' if i == j else '->'
  67. dir = '--'
  68. s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n'
  69. #' (' + self.decoder_types[(i, j)].__name__ + '):\n'
  70. for r in rels:
  71. s += ' - ' + r.name + '\n'
  72. return s.strip()