IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

59 рядків
2.4KB

  1. from collections import defaultdict
  2. from .decode import BilinearDecoder
  3. from .weights import init_glorot
  4. class Data(object):
  5. def __init__(self):
  6. self.node_types = []
  7. self.relation_types = []
  8. self.decoder_types = defaultdict(lambda: BilinearDecoder)
  9. self.latent_node = []
  10. def add_node_type(self, name, count, latent_length):
  11. self.node_types.append(name)
  12. self.latent_node.append(init_glorot(count, latent_length))
  13. def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name):
  14. n = len(self.node_types)
  15. if node_type_row >= n or node_type_column >= n:
  16. raise ValueError('Node type index out of bounds, add node type first')
  17. self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name))
  18. _ = self.decoder_types[(node_type_row, node_type_column)]
  19. def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
  20. if (node_type_row, node_type_column) not in self.decoder_types:
  21. raise ValueError('Relation type not found, add relation first')
  22. self.decoder_types[(node_type_row, node_type_column)] = decoder_class
  23. def get_adjacency_matrices(self, node_type_row, node_type_column):
  24. rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
  25. if len(rels) == 0:
  26. def __repr__(self):
  27. n = len(self.node_types)
  28. if n == 0:
  29. return 'Empty GNN Data'
  30. s = ''
  31. s += 'GNN Data with:\n'
  32. s += '- ' + str(n) + ' node type(s):\n'
  33. for nt in self.node_types:
  34. s += ' - ' + nt + '\n'
  35. if len(self.relation_types) == 0:
  36. s += '- No relation types\n'
  37. return s.strip()
  38. s += '- ' + str(len(self.relation_types)) + ' relation type(s):\n'
  39. for i in range(n):
  40. for j in range(n):
  41. rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
  42. if len(rels) == 0:
  43. continue
  44. # dir = '<->' if i == j else '->'
  45. dir = '--'
  46. s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ' (' + self.decoder_types[(i, j)].__name__ + '):\n'
  47. for r in rels:
  48. s += ' - ' + r[3] + '\n'
  49. return s.strip()