IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

235 lines
8.2KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from dataclasses import dataclass, field
  7. import torch
  8. from typing import List, \
  9. Dict, \
  10. Tuple, \
  11. Any, \
  12. Type
  13. from .decode import DEDICOMDecoder, \
  14. BilinearDecoder
  15. import numpy as np
  16. def _equal(x: torch.Tensor, y: torch.Tensor):
  17. if x.is_sparse ^ y.is_sparse:
  18. raise ValueError('Cannot mix sparse and dense tensors')
  19. if not x.is_sparse:
  20. return (x == y)
  21. # if x.shape != y.shape:
  22. # return torch.tensor(0, dtype=torch.uint8)
  23. return ((x - y).coalesce().values() == 0)
  24. x = x.coalesce()
  25. indices_x = np.empty(x.indices().shape[1], dtype=np.object)
  26. indices_x[:] = list(map(tuple, x.indices().transpose(0, 1)))
  27. order_x = np.argsort(indices_x)
  28. #order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx])
  29. y = y.coalesce()
  30. indices_y = np.empty(y.indices().shape[1], dtype=np.object)
  31. indices_y[:] = list(map(tuple, y.indices().transpose(0, 1)))
  32. order_y = np.argsort(indices_y)
  33. # order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx])
  34. # print(indices_x.shape, indices_y.shape)
  35. if not len(indices_x) == len(indices_y):
  36. return torch.tensor(0, dtype=torch.uint8)
  37. if not np.all(indices_x[order_x] == indices_y[order_y]):
  38. return torch.tensor(0, dtype=torch.uint8)
  39. return (x.values()[order_x] == y.values()[order_y])
  40. @dataclass
  41. class NodeType(object):
  42. name: str
  43. count: int
  44. @dataclass
  45. class RelationTypeBase(object):
  46. name: str
  47. node_type_row: int
  48. node_type_column: int
  49. adjacency_matrix: torch.Tensor
  50. adjacency_matrix_backward: torch.Tensor
  51. @dataclass
  52. class RelationType(RelationTypeBase):
  53. pass
  54. @dataclass
  55. class RelationFamilyBase(object):
  56. data: 'Data'
  57. name: str
  58. node_type_row: int
  59. node_type_column: int
  60. is_symmetric: bool
  61. decoder_class: Type
  62. @dataclass
  63. class RelationFamily(RelationFamilyBase):
  64. relation_types: List[RelationType] = None
  65. def __post_init__(self) -> None:
  66. if not self.is_symmetric and \
  67. self.decoder_class != DEDICOMDecoder and \
  68. self.decoder_class != BilinearDecoder:
  69. raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
  70. self.relation_types = []
  71. def add_relation_type(self,
  72. name: str, adjacency_matrix: torch.Tensor,
  73. adjacency_matrix_backward: torch.Tensor = None) -> None:
  74. name = str(name)
  75. node_type_row = self.node_type_row
  76. node_type_column = self.node_type_column
  77. if adjacency_matrix is None and adjacency_matrix_backward is None:
  78. raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')
  79. if adjacency_matrix is not None and \
  80. not isinstance(adjacency_matrix, torch.Tensor):
  81. raise ValueError('adjacency_matrix must be a torch.Tensor')
  82. if adjacency_matrix_backward is not None \
  83. and not isinstance(adjacency_matrix_backward, torch.Tensor):
  84. raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
  85. if adjacency_matrix is not None and \
  86. adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
  87. self.data.node_types[node_type_column].count):
  88. raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
  89. if adjacency_matrix_backward is not None and \
  90. adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
  91. self.data.node_types[node_type_row].count):
  92. raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')
  93. if node_type_row == node_type_column and \
  94. adjacency_matrix_backward is not None:
  95. raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
  96. if self.is_symmetric and adjacency_matrix_backward is not None:
  97. raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')
  98. if self.is_symmetric and node_type_row == node_type_column and \
  99. not torch.all(_equal(adjacency_matrix,
  100. adjacency_matrix.transpose(0, 1))):
  101. raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')
  102. if not self.is_symmetric and node_type_row != node_type_column and \
  103. adjacency_matrix_backward is None:
  104. raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None')
  105. if self.is_symmetric and node_type_row != node_type_column:
  106. adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
  107. self.relation_types.append(RelationType(name,
  108. node_type_row, node_type_column,
  109. adjacency_matrix, adjacency_matrix_backward))
  110. def node_name(self, index):
  111. return self.data.node_types[index].name
  112. def __repr__(self):
  113. s = 'Relation family %s' % self.name
  114. for r in self.relation_types:
  115. s += '\n - %s%s' % (r.name, ' (two-way)' \
  116. if (r.adjacency_matrix is not None \
  117. and r.adjacency_matrix_backward is not None) \
  118. or self.node_type_row == self.node_type_column \
  119. else '%s <- %s' % (self.node_name(self.node_type_row),
  120. self.node_name(self.node_type_column)))
  121. return s
  122. def repr_indented(self):
  123. s = ' - %s' % self.name
  124. for r in self.relation_types:
  125. s += '\n - %s%s' % (r.name, ' (two-way)' \
  126. if (r.adjacency_matrix is not None \
  127. and r.adjacency_matrix_backward is not None) \
  128. or self.node_type_row == self.node_type_column \
  129. else '%s <- %s' % (self.node_name(self.node_type_row),
  130. self.node_name(self.node_type_column)))
  131. return s
  132. class Data(object):
  133. node_types: List[NodeType]
  134. relation_families: List[RelationFamily]
  135. def __init__(self) -> None:
  136. self.node_types = []
  137. self.relation_families = []
  138. def add_node_type(self, name: str, count: int) -> None:
  139. name = str(name)
  140. count = int(count)
  141. if not name:
  142. raise ValueError('You must provide a non-empty node type name')
  143. if count <= 0:
  144. raise ValueError('You must provide a positive node count')
  145. self.node_types.append(NodeType(name, count))
  146. def add_relation_family(self, name: str, node_type_row: int,
  147. node_type_column: int, is_symmetric: bool,
  148. decoder_class: Type = DEDICOMDecoder):
  149. name = str(name)
  150. node_type_row = int(node_type_row)
  151. node_type_column = int(node_type_column)
  152. is_symmetric = bool(is_symmetric)
  153. if node_type_row < 0 or node_type_row >= len(self.node_types):
  154. raise ValueError('node_type_row outside of the valid range of node types')
  155. if node_type_column < 0 or node_type_column >= len(self.node_types):
  156. raise ValueError('node_type_column outside of the valid range of node types')
  157. fam = RelationFamily(self, name, node_type_row, node_type_column,
  158. is_symmetric, decoder_class)
  159. self.relation_families.append(fam)
  160. return fam
  161. def __repr__(self):
  162. n = len(self.node_types)
  163. if n == 0:
  164. return 'Empty Icosagon Data'
  165. s = ''
  166. s += 'Icosagon Data with:\n'
  167. s += '- ' + str(n) + ' node type(s):\n'
  168. for nt in self.node_types:
  169. s += ' - ' + nt.name + '\n'
  170. if len(self.relation_families) == 0:
  171. s += '- No relation families\n'
  172. return s.strip()
  173. s += '- %d relation families:\n' % len(self.relation_families)
  174. for fam in self.relation_families:
  175. s += fam.repr_indented() + '\n'
  176. return s.strip()