IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Du kan inte välja fler än 25 ämnen. Ämnen måste börja med en bokstav eller siffra, kan innehålla bindestreck ('-') och får vara högst 35 tecken långa.

97 lines
3.9KB

  1. from collections import defaultdict
  2. from .decode import BilinearDecoder
  3. from .weights import init_glorot
  4. class NodeType(object):
  5. def __init__(self, name, count):
  6. self.name = name
  7. self.count = count
  8. class RelationType(object):
  9. def __init__(self, name, node_type_row, node_type_column,
  10. adjacency_matrix):
  11. self.name = name
  12. self.node_type_row = node_type_row
  13. self.node_type_column = node_type_column
  14. self.adjacency_matrix = adjacency_matrix
  15. def get_adjacency_matrix(node_type_row, node_type_column):
  16. if self.node_type_row == node_type_row and \
  17. self.node_type_column == node_type_column:
  18. return self.adjacency_matrix
  19. elif self.node_type_row == node_type_column and \
  20. self.node_type_column == node_type_row:
  21. return self.adjacency_matrix.transpose(0, 1)
  22. else:
  23. raise ValueError('Specified row/column types do not correspond to this relation')
  24. class Data(object):
  25. def __init__(self):
  26. self.node_types = []
  27. self.relation_types = defaultdict(list)
  28. # self.decoder_types = defaultdict(lambda: BilinearDecoder)
  29. # self.latent_node = []
  30. def add_node_type(self, name, count): # , latent_length):
  31. self.node_types.append(NodeType(name, count))
  32. # self.latent_node.append(init_glorot(count, latent_length))
  33. def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix):
  34. n = len(self.node_types)
  35. if node_type_row >= n or node_type_column >= n:
  36. raise ValueError('Node type index out of bounds, add node type first')
  37. key = (node_type_row, node_type_column)
  38. self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix))
  39. # _ = self.decoder_types[(node_type_row, node_type_column)]
  40. #def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
  41. # if (node_type_row, node_type_column) not in self.decoder_types:
  42. # raise ValueError('Relation type not found, add relation first')
  43. # self.decoder_types[(node_type_row, node_type_column)] = decoder_class
  44. def get_adjacency_matrices(self, node_type_row, node_type_column):
  45. # rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
  46. key = (node_type_row, node_type_column)
  47. if key not in self.relation_types:
  48. raise ValueError('Relation type not found')
  49. rels = self.relation_types[key]
  50. rels = list(map(lambda a: a.adjacency_matrix, rels))
  51. return rels
  52. def __repr__(self):
  53. n = len(self.node_types)
  54. if n == 0:
  55. return 'Empty GNN Data'
  56. s = ''
  57. s += 'GNN Data with:\n'
  58. s += '- ' + str(n) + ' node type(s):\n'
  59. for nt in self.node_types:
  60. s += ' - ' + nt.name + '\n'
  61. if len(self.relation_types) == 0:
  62. s += '- No relation types\n'
  63. return s.strip()
  64. n = sum(map(len, self.relation_types))
  65. s += '- ' + str(n) + ' relation type(s):\n'
  66. for i in range(n):
  67. for j in range(n):
  68. key = (i, j)
  69. if key not in self.relation_types:
  70. continue
  71. rels = self.relation_types[key]
  72. # rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
  73. #if len(rels) == 0:
  74. # continue
  75. # dir = '<->' if i == j else '->'
  76. dir = '--'
  77. s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n'
  78. #' (' + self.decoder_types[(i, j)].__name__ + '):\n'
  79. for r in rels:
  80. s += ' - ' + r.name + '\n'
  81. return s.strip()