IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

126 рядків
4.5KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from dataclasses import dataclass
  7. import torch
  8. from typing import List, \
  9. Dict, \
  10. Tuple
  11. @dataclass
  12. class NodeType(object):
  13. name: str
  14. count: int
  15. @dataclass
  16. class RelationType(object):
  17. name: str
  18. node_type_row: int
  19. node_type_column: int
  20. adjacency_matrix: torch.Tensor
  21. is_autogenerated: bool = False
  22. class Data(object):
  23. node_types: List[NodeType]
  24. relation_types: Dict[Tuple[int, int], List[RelationType]]
  25. def __init__(self) -> None:
  26. self.node_types = []
  27. self.relation_types = defaultdict(list)
  28. def add_node_type(self, name: str, count: int) -> None:
  29. name = str(name)
  30. count = int(count)
  31. if not name:
  32. raise ValueError('You must provide a non-empty node type name')
  33. if count <= 0:
  34. raise ValueError('You must provide a positive node count')
  35. self.node_types.append(NodeType(name, count))
  36. def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
  37. adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None,
  38. two_way: bool = True) -> None:
  39. name = str(name)
  40. node_type_row = int(node_type_row)
  41. node_type_column = int(node_type_column)
  42. if node_type_row < 0 or node_type_row > len(self.node_types):
  43. raise ValueError('node_type_row outside of the valid range of node types')
  44. if node_type_column < 0 or node_type_column > len(self.node_types):
  45. raise ValueError('node_type_column outside of the valid range of node types')
  46. if not isinstance(adjacency_matrix, torch.Tensor):
  47. raise ValueError('adjacency_matrix must be a torch.Tensor')
  48. if adjacency_matrix_backward and not isinstance(adjacency_matrix_backward, torch.Tensor):
  49. raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
  50. if adjacency_matrix.shape != (self.node_types[node_type_row].count,
  51. self.node_types[node_type_column].count):
  52. raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
  53. if adjacency_matrix_backward is not None and \
  54. adjacency_matrix_backward.shape != (self.node_types[node_type_column].count,
  55. self.node_types[node_type_row].count):
  56. raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)')
  57. two_way = bool(two_way)
  58. if node_type_row == node_type_column and \
  59. adjacency_matrix_backward is not None:
  60. raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
  61. self.relation_types[node_type_row, node_type_column].append(
  62. RelationType(name, node_type_row, node_type_column,
  63. adjacency_matrix, False))
  64. if node_type_row != node_type_column and two_way:
  65. if adjacency_matrix_backward is None:
  66. adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
  67. self.relation_types[node_type_column, node_type_row].append(
  68. RelationType(name, node_type_column, node_type_row,
  69. adjacency_matrix_backward, True))
  70. def __repr__(self):
  71. n = len(self.node_types)
  72. if n == 0:
  73. return 'Empty Icosagon Data'
  74. s = ''
  75. s += 'Icosagon Data with:\n'
  76. s += '- ' + str(n) + ' node type(s):\n'
  77. for nt in self.node_types:
  78. s += ' - ' + nt.name + '\n'
  79. if len(self.relation_types) == 0:
  80. s += '- No relation types\n'
  81. return s.strip()
  82. s_1 = ''
  83. count = 0
  84. for node_type_row in range(n):
  85. for node_type_column in range(n):
  86. if (node_type_row, node_type_column) not in self.relation_types:
  87. continue
  88. s_1 += ' - ' + self.node_types[node_type_row].name + ' -- ' + \
  89. self.node_types[node_type_column].name + ':\n'
  90. for r in self.relation_types[node_type_row, node_type_column]:
  91. if r.is_autogenerated:
  92. continue
  93. s_1 += ' - ' + r.name + '\n'
  94. count += 1
  95. s += '- %d relation type(s):\n' % count
  96. s += s_1
  97. return s.strip()