IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Você não pode selecionar mais de 25 tópicos Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

209 linhas
7.6KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from dataclasses import dataclass, field
  7. import torch
  8. from typing import List, \
  9. Dict, \
  10. Tuple, \
  11. Any, \
  12. Type
  13. from .decode import DEDICOMDecoder, \
  14. BilinearDecoder
  15. def _equal(x: torch.Tensor, y: torch.Tensor):
  16. if x.is_sparse ^ y.is_sparse:
  17. raise ValueError('Cannot mix sparse and dense tensors')
  18. if not x.is_sparse:
  19. return (x == y)
  20. x = x.coalesce()
  21. indices_x = list(map(tuple, x.indices().transpose(0, 1)))
  22. order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx])
  23. y = y.coalesce()
  24. indices_y = list(map(tuple, y.indices().transpose(0, 1)))
  25. order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx])
  26. return (x.values()[order_x] == y.values()[order_y])
  27. @dataclass
  28. class NodeType(object):
  29. name: str
  30. count: int
  31. @dataclass
  32. class RelationType(object):
  33. name: str
  34. node_type_row: int
  35. node_type_column: int
  36. adjacency_matrix: torch.Tensor
  37. two_way: bool
  38. hints: Dict[str, Any] = field(default_factory=dict)
  39. class RelationFamily(object):
  40. def __init__(self,
  41. data: 'Data',
  42. name: str,
  43. node_type_row: int,
  44. node_type_column: int,
  45. is_symmetric: bool,
  46. decoder_class: Type) -> None:
  47. if not is_symmetric and \
  48. decoder_class != DEDICOMDecoder and \
  49. decoder_class != BilinearDecoder:
  50. raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
  51. self.data = data
  52. self.name = name
  53. self.node_type_row = node_type_row
  54. self.node_type_column = node_type_column
  55. self.is_symmetric = is_symmetric
  56. self.decoder_class = decoder_class
  57. self.relation_types = { (node_type_row, node_type_column): [],
  58. (node_type_column, node_type_row): [] }
  59. def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
  60. adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None,
  61. two_way: bool = True) -> None:
  62. name = str(name)
  63. node_type_row = int(node_type_row)
  64. node_type_column = int(node_type_column)
  65. if (node_type_row, node_type_column) not in self.relation_types:
  66. raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family')
  67. if node_type_row < 0 or node_type_row >= len(self.data.node_types):
  68. raise ValueError('node_type_row outside of the valid range of node types')
  69. if node_type_column < 0 or node_type_column >= len(self.data.node_types):
  70. raise ValueError('node_type_column outside of the valid range of node types')
  71. if not isinstance(adjacency_matrix, torch.Tensor):
  72. raise ValueError('adjacency_matrix must be a torch.Tensor')
  73. if adjacency_matrix_backward is not None \
  74. and not isinstance(adjacency_matrix_backward, torch.Tensor):
  75. raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
  76. if adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
  77. self.data.node_types[node_type_column].count):
  78. raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
  79. if adjacency_matrix_backward is not None and \
  80. adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
  81. self.data.node_types[node_type_row].count):
  82. raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)')
  83. if node_type_row == node_type_column and \
  84. adjacency_matrix_backward is not None:
  85. raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
  86. if self.is_symmetric and adjacency_matrix_backward is not None:
  87. raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')
  88. if self.is_symmetric and node_type_row == node_type_column and \
  89. not torch.all(_equal(adjacency_matrix,
  90. adjacency_matrix.transpose(0, 1))):
  91. raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')
  92. two_way = bool(two_way)
  93. if node_type_row != node_type_column and two_way:
  94. print('%d != %d' % (node_type_row, node_type_column))
  95. if adjacency_matrix_backward is None:
  96. adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
  97. self.relation_types[node_type_column, node_type_row].append(
  98. RelationType(name, node_type_column, node_type_row,
  99. adjacency_matrix_backward, two_way, { 'display': False }))
  100. self.relation_types[node_type_row, node_type_column].append(
  101. RelationType(name, node_type_row, node_type_column,
  102. adjacency_matrix, two_way))
  103. def node_name(self, index):
  104. return self.data.node_types[index].name
  105. def __repr__(self):
  106. s = 'Relation family %s' % self.name
  107. for (node_type_row, node_type_column), rels in self.relation_types.items():
  108. for r in rels:
  109. if 'display' in r.hints and not r.hints['display']:
  110. continue
  111. s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \
  112. (self.node_name(node_type_row), self.node_name(node_type_column)))
  113. return s
  114. def repr_indented(self):
  115. s = ' - %s' % self.name
  116. for (node_type_row, node_type_column), rels in self.relation_types.items():
  117. for r in rels:
  118. if 'display' in r.hints and not r.hints['display']:
  119. continue
  120. s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \
  121. (self.node_name(node_type_row), self.node_name(node_type_column)))
  122. return s
  123. class Data(object):
  124. node_types: List[NodeType]
  125. relation_types: Dict[Tuple[int, int], List[RelationType]]
  126. def __init__(self) -> None:
  127. self.node_types = []
  128. self.relation_families = []
  129. def add_node_type(self, name: str, count: int) -> None:
  130. name = str(name)
  131. count = int(count)
  132. if not name:
  133. raise ValueError('You must provide a non-empty node type name')
  134. if count <= 0:
  135. raise ValueError('You must provide a positive node count')
  136. self.node_types.append(NodeType(name, count))
  137. def add_relation_family(self, name: str, node_type_row: int,
  138. node_type_column: int, is_symmetric: bool,
  139. decoder_class: Type = DEDICOMDecoder):
  140. fam = RelationFamily(self, name, node_type_row, node_type_column,
  141. is_symmetric, decoder_class)
  142. self.relation_families.append(fam)
  143. return fam
  144. def __repr__(self):
  145. n = len(self.node_types)
  146. if n == 0:
  147. return 'Empty Icosagon Data'
  148. s = ''
  149. s += 'Icosagon Data with:\n'
  150. s += '- ' + str(n) + ' node type(s):\n'
  151. for nt in self.node_types:
  152. s += ' - ' + nt.name + '\n'
  153. if len(self.relation_families) == 0:
  154. s += '- No relation families\n'
  155. return s.strip()
  156. s += '- %d relation families:\n' % len(self.relation_families)
  157. for fam in self.relation_families:
  158. s += fam.repr_indented() + '\n'
  159. return s.strip()