IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

85 líneas
3.4KB

  1. from collections import defaultdict
  2. from .decode import BilinearDecoder
  3. from .weights import init_glorot
  4. class NodeType(object):
  5. def __init__(self, name, count):
  6. self.name = name
  7. self.count = count
  8. class RelationType(object):
  9. def __init__(self, name, node_type_row, node_type_column,
  10. adjacency_matrix):
  11. self.name = name
  12. self.node_type_row = node_type_row
  13. self.node_type_column = node_type_column
  14. self.adjacency_matrix = adjacency_matrix
  15. class Data(object):
  16. def __init__(self):
  17. self.node_types = []
  18. self.relation_types = defaultdict(list)
  19. # self.decoder_types = defaultdict(lambda: BilinearDecoder)
  20. # self.latent_node = []
  21. def add_node_type(self, name, count): # , latent_length):
  22. self.node_types.append(NodeType(name, count))
  23. # self.latent_node.append(init_glorot(count, latent_length))
  24. def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix):
  25. n = len(self.node_types)
  26. if node_type_row >= n or node_type_column >= n:
  27. raise ValueError('Node type index out of bounds, add node type first')
  28. key = (node_type_row, node_type_column)
  29. self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix))
  30. # _ = self.decoder_types[(node_type_row, node_type_column)]
  31. #def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
  32. # if (node_type_row, node_type_column) not in self.decoder_types:
  33. # raise ValueError('Relation type not found, add relation first')
  34. # self.decoder_types[(node_type_row, node_type_column)] = decoder_class
  35. def get_adjacency_matrices(self, node_type_row, node_type_column):
  36. # rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
  37. key = (node_type_row, node_type_column)
  38. if key not in self.relation_types:
  39. raise ValueError('Relation type not found')
  40. rels = self.relation_types[key]
  41. rels = list(map(lambda a: a.adjacency_matrix, rels))
  42. return rels
  43. def __repr__(self):
  44. n = len(self.node_types)
  45. if n == 0:
  46. return 'Empty GNN Data'
  47. s = ''
  48. s += 'GNN Data with:\n'
  49. s += '- ' + str(n) + ' node type(s):\n'
  50. for nt in self.node_types:
  51. s += ' - ' + nt.name + '\n'
  52. if len(self.relation_types) == 0:
  53. s += '- No relation types\n'
  54. return s.strip()
  55. n = sum(map(len, self.relation_types))
  56. s += '- ' + str(n) + ' relation type(s):\n'
  57. for i in range(n):
  58. for j in range(n):
  59. key = (i, j)
  60. if key not in self.relation_types:
  61. continue
  62. rels = self.relation_types[key]
  63. # rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
  64. #if len(rels) == 0:
  65. # continue
  66. # dir = '<->' if i == j else '->'
  67. dir = '--'
  68. s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n'
  69. #' (' + self.decoder_types[(i, j)].__name__ + '):\n'
  70. for r in rels:
  71. s += ' - ' + r.name + '\n'
  72. return s.strip()