IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

129 Zeilen
4.6KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from dataclasses import dataclass, field
  7. import torch
  8. from typing import List, \
  9. Dict, \
  10. Tuple, \
  11. Any
  @dataclass
  class NodeType:
      """Record describing one kind of node: its name and how many exist."""

      # Human-readable name of the node type.
      name: str
      # Number of nodes belonging to this type.
      count: int
  @dataclass
  class RelationType:
      """Record describing one relation between two node types.

      The relation is identified by ``name`` and connects the node type at
      index ``node_type_row`` with the node type at index
      ``node_type_column``.
      """

      # Name of the relation.
      name: str
      # Index of the node type associated with the rows of the matrix.
      node_type_row: int
      # Index of the node type associated with the columns of the matrix.
      node_type_column: int
      # Adjacency matrix of the relation.
      adjacency_matrix: torch.Tensor
      # Free-form extra information about the relation; every instance gets
      # its own fresh dict when none is supplied.
      hints: Dict[str, Any] = field(default_factory=dict)
  class Data(object):
      """Registry of the node types and relation types that describe a graph.

      Node types are kept in insertion order in ``node_types`` and are
      referred to everywhere else by their index in that list. Relation
      types are grouped in ``relation_types`` under the key
      ``(node_type_row, node_type_column)``.
      """

      # Quoted (forward references) so the annotations are not evaluated when
      # the class body executes.
      node_types: 'List[NodeType]'
      relation_types: 'Dict[Tuple[int, int], List[RelationType]]'

      def __init__(self) -> None:
          """Create a container with no node types and no relation types."""
          self.node_types = []
          self.relation_types = defaultdict(list)

      def add_node_type(self, name: str, count: int) -> None:
          """Register a new node type at the end of ``node_types``.

          Args:
              name: Name of the node type; converted with ``str()``.
              count: Number of nodes of this type; converted with ``int()``.

          Raises:
              ValueError: If ``name`` is empty or ``count`` is not positive.
          """
          name = str(name)
          count = int(count)
          if not name:
              raise ValueError('You must provide a non-empty node type name')
          if count <= 0:
              raise ValueError('You must provide a positive node count')
          self.node_types.append(NodeType(name, count))

      def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
              adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None,
              two_way: bool = True) -> None:
          """Register a relation type between two previously added node types.

          The relation is stored under ``(node_type_row, node_type_column)``.
          When the two node types differ and ``two_way`` is true, a reverse
          relation is also stored under ``(node_type_column, node_type_row)``.
          It uses ``adjacency_matrix_backward`` if given, otherwise the
          transpose of ``adjacency_matrix`` (then tagged with the hint
          ``'symmetric': True``). The reverse relation always carries the hint
          ``'display': False`` so that ``__repr__`` does not list it twice.

          Args:
              name: Name of the relation; converted with ``str()``.
              node_type_row: Index into ``node_types`` for the matrix rows.
              node_type_column: Index into ``node_types`` for the matrix columns.
              adjacency_matrix: Tensor of shape
                  ``(row node count, column node count)``.
              adjacency_matrix_backward: Optional tensor of shape
                  ``(column node count, row node count)`` for the reverse
                  direction. Not allowed when both indices are equal; ignored
                  when ``two_way`` is false.
              two_way: Whether to register the reverse relation as well.

          Raises:
              ValueError: If an index is out of range, a matrix is not a
                  ``torch.Tensor`` or has the wrong shape, or a backward
                  matrix is given for a relation within a single node type.
          """
          name = str(name)
          node_type_row = int(node_type_row)
          node_type_column = int(node_type_column)

          # Valid indices are 0 .. len(node_types) - 1; an index equal to the
          # length must be rejected here rather than fail later on lookup.
          num_node_types = len(self.node_types)
          if not 0 <= node_type_row < num_node_types:
              raise ValueError('node_type_row outside of the valid range of node types')
          if not 0 <= node_type_column < num_node_types:
              raise ValueError('node_type_column outside of the valid range of node types')

          if not isinstance(adjacency_matrix, torch.Tensor):
              raise ValueError('adjacency_matrix must be a torch.Tensor')

          # Compare against None explicitly: truth-testing a tensor with more
          # than one element raises RuntimeError.
          if adjacency_matrix_backward is not None and \
                  not isinstance(adjacency_matrix_backward, torch.Tensor):
              raise ValueError('adjacency_matrix_backward must be a torch.Tensor')

          num_row_nodes = self.node_types[node_type_row].count
          num_column_nodes = self.node_types[node_type_column].count

          if adjacency_matrix.shape != (num_row_nodes, num_column_nodes):
              raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')

          if adjacency_matrix_backward is not None and \
                  adjacency_matrix_backward.shape != (num_column_nodes, num_row_nodes):
              raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')

          two_way = bool(two_way)

          if node_type_row == node_type_column and \
                  adjacency_matrix_backward is not None:
              raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')

          self.relation_types[node_type_row, node_type_column].append(
              RelationType(name, node_type_row, node_type_column,
                  adjacency_matrix))

          if node_type_row != node_type_column and two_way:
              # The reverse entry is bookkeeping only; hide it from __repr__.
              hints = { 'display': False }
              if adjacency_matrix_backward is None:
                  adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
                  hints['symmetric'] = True
              self.relation_types[node_type_column, node_type_row].append(
                  RelationType(name, node_type_column, node_type_row,
                      adjacency_matrix_backward, hints))

      def __repr__(self):
          """Return a multi-line summary of the node and relation types.

          Relation types whose ``'display'`` hint is false are neither listed
          nor counted.
          """
          n = len(self.node_types)
          if n == 0:
              return 'Empty Icosagon Data'

          parts = ['Icosagon Data with:\n',
                   '- ' + str(n) + ' node type(s):\n']
          parts.extend(' - ' + nt.name + '\n' for nt in self.node_types)

          if len(self.relation_types) == 0:
              parts.append('- No relation types\n')
              return ''.join(parts).strip()

          # Collected separately because the total count is printed first.
          relation_parts = []
          count = 0
          for node_type_row in range(n):
              for node_type_column in range(n):
                  if (node_type_row, node_type_column) not in self.relation_types:
                      continue

                  relation_parts.append(
                      ' - ' + self.node_types[node_type_row].name + ' -- ' +
                      self.node_types[node_type_column].name + ':\n')

                  for r in self.relation_types[node_type_row, node_type_column]:
                      if not r.hints.get('display', True):
                          continue
                      relation_parts.append(' - ' + r.name + '\n')
                      count += 1

          parts.append('- %d relation type(s):\n' % count)
          parts.extend(relation_parts)
          return ''.join(parts).strip()