IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

128 lines
4.6KB

  1. from collections import defaultdict
  2. from dataclasses import dataclass
  3. import torch
  4. @dataclass
  5. class NodeType(object):
  6. name: str
  7. count: int
  8. @dataclass
  9. class RelationType(object):
  10. name: str
  11. node_type_row: int
  12. node_type_column: int
  13. adjacency_matrix: torch.Tensor
  14. is_autogenerated: bool
  15. class Data(object):
  16. def __init__(self) -> None:
  17. self.node_types = []
  18. self.relation_types = defaultdict(lambda: defaultdict(list))
  19. def add_node_type(self, name: str, count: int) -> None:
  20. name = str(name)
  21. count = int(count)
  22. if not name:
  23. raise ValueError('You must provide a non-empty node type name')
  24. if count <= 0:
  25. raise ValueError('You must provide a positive node count')
  26. self.node_types.append(NodeType(name, count))
  27. def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
  28. adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None,
  29. two_way: bool = True) -> None:
  30. name = str(name)
  31. node_type_row = int(node_type_row)
  32. node_type_column = int(node_type_column)
  33. if node_type_row < 0 or node_type_row > len(self.node_types):
  34. raise ValueError('node_type_row outside of the valid range of node types')
  35. if node_type_column < 0 or node_type_column > len(self.node_types):
  36. raise ValueError('node_type_column outside of the valid range of node types')
  37. if not isinstance(adjacency_matrix, torch.Tensor):
  38. raise ValueError('adjacency_matrix must be a torch.Tensor')
  39. if adjacency_matrix_backward and not isinstance(adjacency_matrix_backward, torch.Tensor):
  40. raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
  41. if adjacency_matrix.shape != (self.node_types[node_type_row].count,
  42. self.node_types[node_type_column].count):
  43. raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
  44. if adjacency_matrix_backward is not None and \
  45. adjacency_matrix_backward.shape != (self.node_types[node_type_column].count,
  46. self.node_types[node_type_row].count):
  47. raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)')
  48. two_way = bool(two_way)
  49. if node_type_row == node_type_column and \
  50. adjacency_matrix_backward is not None:
  51. raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
  52. self.relation_types[node_type_row][node_type_column].append(
  53. RelationType(name, node_type_row, node_type_column,
  54. adjacency_matrix, False))
  55. if node_type_row != node_type_column and two_way:
  56. if adjacency_matrix_backward is None:
  57. adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
  58. self.relation_types[node_type_column][node_type_row].append(
  59. RelationType(name, node_type_column, node_type_row,
  60. adjacency_matrix_backward, True))
  61. def __repr__(self):
  62. n = len(self.node_types)
  63. if n == 0:
  64. return 'Empty Icosagon Data'
  65. s = ''
  66. s += 'Icosagon Data with:\n'
  67. s += '- ' + str(n) + ' node type(s):\n'
  68. for nt in self.node_types:
  69. s += ' - ' + nt.name + '\n'
  70. if len(self.relation_types) == 0:
  71. s += '- No relation types\n'
  72. return s.strip()
  73. s_1 = ''
  74. count = 0
  75. for i in range(n):
  76. for j in range(n):
  77. if i not in self.relation_types or \
  78. j not in self.relation_types[i]:
  79. continue
  80. s_1 += ' - ' + self.node_types[i].name + ' -- ' + \
  81. self.node_types[j].name + ':\n'
  82. for r in self.relation_types[i][j]:
  83. if r.is_autogenerated:
  84. continue
  85. s_1 += ' - ' + r.name + '\n'
  86. count += 1
  87. s += '- %d relation type(s):\n' % count
  88. s += s_1
  89. return s.strip()
  90. # n = sum(map(len, self.relation_types))
  91. #
  92. # for i in range(n):
  93. # for j in range(n):
  94. # key = (i, j)
  95. # if key not in self.relation_types:
  96. # continue
  97. # rels = self.relation_types[key]
  98. #
  99. # for r in rels:
  100. #
  101. # return s.strip()