IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

104 line
4.1KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from .weights import init_glorot
  7. class NodeType(object):
  8. def __init__(self, name, count):
  9. self.name = name
  10. self.count = count
  11. class RelationType(object):
  12. def __init__(self, name, node_type_row, node_type_column,
  13. adjacency_matrix):
  14. self.name = name
  15. self.node_type_row = node_type_row
  16. self.node_type_column = node_type_column
  17. self.adjacency_matrix = adjacency_matrix
  18. def get_adjacency_matrix(node_type_row, node_type_column):
  19. if self.node_type_row == node_type_row and \
  20. self.node_type_column == node_type_column:
  21. return self.adjacency_matrix
  22. elif self.node_type_row == node_type_column and \
  23. self.node_type_column == node_type_row:
  24. return self.adjacency_matrix.transpose(0, 1)
  25. else:
  26. raise ValueError('Specified row/column types do not correspond to this relation')
  27. class Data(object):
  28. def __init__(self):
  29. self.node_types = []
  30. self.relation_types = defaultdict(list)
  31. # self.decoder_types = defaultdict(lambda: BilinearDecoder)
  32. # self.latent_node = []
  33. def add_node_type(self, name, count): # , latent_length):
  34. self.node_types.append(NodeType(name, count))
  35. # self.latent_node.append(init_glorot(count, latent_length))
  36. def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix):
  37. n = len(self.node_types)
  38. if node_type_row >= n or node_type_column >= n:
  39. raise ValueError('Node type index out of bounds, add node type first')
  40. key = (node_type_row, node_type_column)
  41. if adjacency_matrix is not None and not adjacency_matrix.is_sparse:
  42. adjacency_matrix = adjacency_matrix.to_sparse()
  43. self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix))
  44. # _ = self.decoder_types[(node_type_row, node_type_column)]
  45. #def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
  46. # if (node_type_row, node_type_column) not in self.decoder_types:
  47. # raise ValueError('Relation type not found, add relation first')
  48. # self.decoder_types[(node_type_row, node_type_column)] = decoder_class
  49. def get_adjacency_matrices(self, node_type_row, node_type_column):
  50. # rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
  51. key = (node_type_row, node_type_column)
  52. if key not in self.relation_types:
  53. raise ValueError('Relation type not found')
  54. rels = self.relation_types[key]
  55. rels = list(map(lambda a: a.adjacency_matrix, rels))
  56. return rels
  57. def __repr__(self):
  58. n = len(self.node_types)
  59. if n == 0:
  60. return 'Empty GNN Data'
  61. s = ''
  62. s += 'GNN Data with:\n'
  63. s += '- ' + str(n) + ' node type(s):\n'
  64. for nt in self.node_types:
  65. s += ' - ' + nt.name + '\n'
  66. if len(self.relation_types) == 0:
  67. s += '- No relation types\n'
  68. return s.strip()
  69. n = sum(map(len, self.relation_types))
  70. s += '- ' + str(n) + ' relation type(s):\n'
  71. for i in range(n):
  72. for j in range(n):
  73. key = (i, j)
  74. if key not in self.relation_types:
  75. continue
  76. rels = self.relation_types[key]
  77. # rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
  78. #if len(rels) == 0:
  79. # continue
  80. # dir = '<->' if i == j else '->'
  81. dir = '--'
  82. s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n'
  83. #' (' + self.decoder_types[(i, j)].__name__ + '):\n'
  84. for r in rels:
  85. s += ' - ' + r.name + '\n'
  86. return s.strip()