IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

210 lines
7.2KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from dataclasses import dataclass, field
  7. import torch
  8. from typing import List, \
  9. Dict, \
  10. Tuple, \
  11. Any, \
  12. Type
  13. from .decode import DEDICOMDecoder, \
  14. BilinearDecoder
  15. import numpy as np
  16. def _equal(x: torch.Tensor, y: torch.Tensor):
  17. if x.is_sparse ^ y.is_sparse:
  18. raise ValueError('Cannot mix sparse and dense tensors')
  19. if not x.is_sparse:
  20. return (x == y)
  21. return ((x - y).coalesce().values() == 0)
  22. @dataclass
  23. class NodeType(object):
  24. name: str
  25. count: int
  26. @dataclass
  27. class RelationTypeBase(object):
  28. name: str
  29. node_type_row: int
  30. node_type_column: int
  31. adjacency_matrix: torch.Tensor
  32. adjacency_matrix_backward: torch.Tensor
  33. @dataclass
  34. class RelationType(RelationTypeBase):
  35. pass
  36. @dataclass
  37. class RelationFamilyBase(object):
  38. data: 'Data'
  39. name: str
  40. node_type_row: int
  41. node_type_column: int
  42. is_symmetric: bool
  43. decoder_class: Type
  44. @dataclass
  45. class RelationFamily(RelationFamilyBase):
  46. relation_types: List[RelationType] = None
  47. def __post_init__(self) -> None:
  48. if not self.is_symmetric and \
  49. self.decoder_class != DEDICOMDecoder and \
  50. self.decoder_class != BilinearDecoder:
  51. raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
  52. self.relation_types = []
  53. def add_relation_type(self,
  54. name: str, adjacency_matrix: torch.Tensor,
  55. adjacency_matrix_backward: torch.Tensor = None) -> None:
  56. name = str(name)
  57. node_type_row = self.node_type_row
  58. node_type_column = self.node_type_column
  59. if adjacency_matrix is None and adjacency_matrix_backward is None:
  60. raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')
  61. if adjacency_matrix is not None and \
  62. not isinstance(adjacency_matrix, torch.Tensor):
  63. raise ValueError('adjacency_matrix must be a torch.Tensor')
  64. if adjacency_matrix_backward is not None \
  65. and not isinstance(adjacency_matrix_backward, torch.Tensor):
  66. raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
  67. if adjacency_matrix is not None and \
  68. adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
  69. self.data.node_types[node_type_column].count):
  70. raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
  71. if adjacency_matrix_backward is not None and \
  72. adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
  73. self.data.node_types[node_type_row].count):
  74. raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')
  75. if node_type_row == node_type_column and \
  76. adjacency_matrix_backward is not None:
  77. raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
  78. if self.is_symmetric and adjacency_matrix_backward is not None:
  79. raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')
  80. if self.is_symmetric and node_type_row == node_type_column and \
  81. not torch.all(_equal(adjacency_matrix,
  82. adjacency_matrix.transpose(0, 1))):
  83. raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')
  84. if not self.is_symmetric and node_type_row != node_type_column and \
  85. adjacency_matrix_backward is None:
  86. raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None')
  87. if self.is_symmetric and node_type_row != node_type_column:
  88. adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
  89. self.relation_types.append(RelationType(name,
  90. node_type_row, node_type_column,
  91. adjacency_matrix, adjacency_matrix_backward))
  92. def node_name(self, index):
  93. return self.data.node_types[index].name
  94. def __repr__(self):
  95. s = 'Relation family %s' % self.name
  96. for r in self.relation_types:
  97. s += '\n - %s%s' % (r.name, ' (two-way)' \
  98. if (r.adjacency_matrix is not None \
  99. and r.adjacency_matrix_backward is not None) \
  100. or self.node_type_row == self.node_type_column \
  101. else '%s <- %s' % (self.node_name(self.node_type_row),
  102. self.node_name(self.node_type_column)))
  103. return s
  104. def repr_indented(self):
  105. s = ' - %s' % self.name
  106. for r in self.relation_types:
  107. s += '\n - %s%s' % (r.name, ' (two-way)' \
  108. if (r.adjacency_matrix is not None \
  109. and r.adjacency_matrix_backward is not None) \
  110. or self.node_type_row == self.node_type_column \
  111. else '%s <- %s' % (self.node_name(self.node_type_row),
  112. self.node_name(self.node_type_column)))
  113. return s
  114. class Data(object):
  115. node_types: List[NodeType]
  116. relation_families: List[RelationFamily]
  117. def __init__(self) -> None:
  118. self.node_types = []
  119. self.relation_families = []
  120. def add_node_type(self, name: str, count: int) -> None:
  121. name = str(name)
  122. count = int(count)
  123. if not name:
  124. raise ValueError('You must provide a non-empty node type name')
  125. if count <= 0:
  126. raise ValueError('You must provide a positive node count')
  127. self.node_types.append(NodeType(name, count))
  128. def add_relation_family(self, name: str, node_type_row: int,
  129. node_type_column: int, is_symmetric: bool,
  130. decoder_class: Type = DEDICOMDecoder):
  131. name = str(name)
  132. node_type_row = int(node_type_row)
  133. node_type_column = int(node_type_column)
  134. is_symmetric = bool(is_symmetric)
  135. if node_type_row < 0 or node_type_row >= len(self.node_types):
  136. raise ValueError('node_type_row outside of the valid range of node types')
  137. if node_type_column < 0 or node_type_column >= len(self.node_types):
  138. raise ValueError('node_type_column outside of the valid range of node types')
  139. fam = RelationFamily(self, name, node_type_row, node_type_column,
  140. is_symmetric, decoder_class)
  141. self.relation_families.append(fam)
  142. return fam
  143. def __repr__(self):
  144. n = len(self.node_types)
  145. if n == 0:
  146. return 'Empty Icosagon Data'
  147. s = ''
  148. s += 'Icosagon Data with:\n'
  149. s += '- ' + str(n) + ' node type(s):\n'
  150. for nt in self.node_types:
  151. s += ' - ' + nt.name + '\n'
  152. if len(self.relation_families) == 0:
  153. s += '- No relation families\n'
  154. return s.strip()
  155. s += '- %d relation families:\n' % len(self.relation_families)
  156. for fam in self.relation_families:
  157. s += fam.repr_indented() + '\n'
  158. return s.strip()