IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

190 lines
7.0KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from dataclasses import dataclass, field
  7. import torch
  8. from typing import List, \
  9. Dict, \
  10. Tuple, \
  11. Any, \
  12. Type
  13. from .decode import DEDICOMDecoder, \
  14. BilinearDecoder
  15. @dataclass
  16. class NodeType(object):
  17. name: str
  18. count: int
  19. @dataclass
  20. class RelationType(object):
  21. name: str
  22. node_type_row: int
  23. node_type_column: int
  24. adjacency_matrix: torch.Tensor
  25. two_way: bool
  26. hints: Dict[str, Any] = field(default_factory=dict)
  27. class RelationFamily(object):
  28. def __init__(self,
  29. data: 'Data',
  30. name: str,
  31. node_type_row: int,
  32. node_type_column: int,
  33. is_symmetric: bool,
  34. decoder_class: Type) -> None:
  35. if not is_symmetric and \
  36. decoder_class != DEDICOMDecoder and \
  37. decoder_class != BilinearDecoder:
  38. raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
  39. self.data = data
  40. self.name = name
  41. self.node_type_row = node_type_row
  42. self.node_type_column = node_type_column
  43. self.is_symmetric = is_symmetric
  44. self.decoder_class = decoder_class
  45. self.relation_types = { (node_type_row, node_type_column): [],
  46. (node_type_column, node_type_row): [] }
  47. def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
  48. adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None,
  49. two_way: bool = True) -> None:
  50. name = str(name)
  51. node_type_row = int(node_type_row)
  52. node_type_column = int(node_type_column)
  53. if (node_type_row, node_type_column) not in self.relation_types:
  54. raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family')
  55. if node_type_row < 0 or node_type_row >= len(self.data.node_types):
  56. raise ValueError('node_type_row outside of the valid range of node types')
  57. if node_type_column < 0 or node_type_column >= len(self.data.node_types):
  58. raise ValueError('node_type_column outside of the valid range of node types')
  59. if not isinstance(adjacency_matrix, torch.Tensor):
  60. raise ValueError('adjacency_matrix must be a torch.Tensor')
  61. if adjacency_matrix_backward is not None \
  62. and not isinstance(adjacency_matrix_backward, torch.Tensor):
  63. raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
  64. if adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
  65. self.data.node_types[node_type_column].count):
  66. raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
  67. if adjacency_matrix_backward is not None and \
  68. adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
  69. self.data.node_types[node_type_row].count):
  70. raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)')
  71. if node_type_row == node_type_column and \
  72. adjacency_matrix_backward is not None:
  73. raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
  74. if self.is_symmetric and adjacency_matrix_backward is not None:
  75. raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')
  76. if self.is_symmetric and node_type_row == node_type_column and \
  77. not torch.all(adjacency_matrix == adjacency_matrix.transpose(0, 1)):
  78. raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')
  79. two_way = bool(two_way)
  80. if node_type_row != node_type_column and two_way:
  81. print('%d != %d' % (node_type_row, node_type_column))
  82. if adjacency_matrix_backward is None:
  83. adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
  84. self.relation_types[node_type_column, node_type_row].append(
  85. RelationType(name, node_type_column, node_type_row,
  86. adjacency_matrix_backward, two_way, { 'display': False }))
  87. self.relation_types[node_type_row, node_type_column].append(
  88. RelationType(name, node_type_row, node_type_column,
  89. adjacency_matrix, two_way))
  90. def node_name(self, index):
  91. return self.data.node_types[index].name
  92. def __repr__(self):
  93. s = 'Relation family %s' % self.name
  94. for (node_type_row, node_type_column), rels in self.relation_types.items():
  95. for r in rels:
  96. if 'display' in r.hints and not r.hints['display']:
  97. continue
  98. s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \
  99. (self.node_name(node_type_row), self.node_name(node_type_column)))
  100. return s
  101. def repr_indented(self):
  102. s = ' - %s' % self.name
  103. for (node_type_row, node_type_column), rels in self.relation_types.items():
  104. for r in rels:
  105. if 'display' in r.hints and not r.hints['display']:
  106. continue
  107. s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \
  108. (self.node_name(node_type_row), self.node_name(node_type_column)))
  109. return s
  110. class Data(object):
  111. node_types: List[NodeType]
  112. relation_types: Dict[Tuple[int, int], List[RelationType]]
  113. def __init__(self) -> None:
  114. self.node_types = []
  115. self.relation_families = []
  116. def add_node_type(self, name: str, count: int) -> None:
  117. name = str(name)
  118. count = int(count)
  119. if not name:
  120. raise ValueError('You must provide a non-empty node type name')
  121. if count <= 0:
  122. raise ValueError('You must provide a positive node count')
  123. self.node_types.append(NodeType(name, count))
  124. def add_relation_family(self, name: str, node_type_row: int,
  125. node_type_column: int, is_symmetric: bool,
  126. decoder_class: Type = DEDICOMDecoder):
  127. fam = RelationFamily(self, name, node_type_row, node_type_column,
  128. is_symmetric, decoder_class)
  129. self.relation_families.append(fam)
  130. return fam
  131. def __repr__(self):
  132. n = len(self.node_types)
  133. if n == 0:
  134. return 'Empty Icosagon Data'
  135. s = ''
  136. s += 'Icosagon Data with:\n'
  137. s += '- ' + str(n) + ' node type(s):\n'
  138. for nt in self.node_types:
  139. s += ' - ' + nt.name + '\n'
  140. if len(self.relation_families) == 0:
  141. s += '- No relation families\n'
  142. return s.strip()
  143. s += '- %d relation families:\n' % len(self.relation_families)
  144. for fam in self.relation_families:
  145. s += fam.repr_indented() + '\n'
  146. return s.strip()