IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data.py 7.2KB

4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from collections import defaultdict
  6. from dataclasses import dataclass, field
  7. import torch
  8. from typing import List, \
  9. Dict, \
  10. Tuple, \
  11. Any, \
  12. Type
  13. from .decode import DEDICOMDecoder, \
  14. BilinearDecoder
  15. import numpy as np
  16. def _equal(x: torch.Tensor, y: torch.Tensor):
  17. if x.is_sparse ^ y.is_sparse:
  18. raise ValueError('Cannot mix sparse and dense tensors')
  19. if not x.is_sparse:
  20. return (x == y)
  21. return ((x - y).coalesce().values() == 0)
  22. @dataclass
  23. class NodeType(object):
  24. name: str
  25. count: int
  26. @dataclass
  27. class RelationTypeBase(object):
  28. name: str
  29. node_type_row: int
  30. node_type_column: int
  31. adjacency_matrix: torch.Tensor
  32. adjacency_matrix_backward: torch.Tensor
  33. @dataclass
  34. class RelationType(RelationTypeBase):
  35. pass
  36. @dataclass
  37. class RelationFamilyBase(object):
  38. data: 'Data'
  39. name: str
  40. node_type_row: int
  41. node_type_column: int
  42. is_symmetric: bool
  43. decoder_class: Type
  44. @dataclass
  45. class RelationFamily(RelationFamilyBase):
  46. relation_types: List[RelationType] = None
  47. def __post_init__(self) -> None:
  48. if not self.is_symmetric and \
  49. self.decoder_class != DEDICOMDecoder and \
  50. self.decoder_class != BilinearDecoder:
  51. raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
  52. self.relation_types = []
  53. def add_relation_type(self,
  54. name: str, adjacency_matrix: torch.Tensor,
  55. adjacency_matrix_backward: torch.Tensor = None) -> None:
  56. name = str(name)
  57. node_type_row = self.node_type_row
  58. node_type_column = self.node_type_column
  59. if adjacency_matrix is None and adjacency_matrix_backward is None:
  60. raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')
  61. if adjacency_matrix is not None and \
  62. not isinstance(adjacency_matrix, torch.Tensor):
  63. raise ValueError('adjacency_matrix must be a torch.Tensor')
  64. if adjacency_matrix_backward is not None \
  65. and not isinstance(adjacency_matrix_backward, torch.Tensor):
  66. raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
  67. if adjacency_matrix is not None and \
  68. adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
  69. self.data.node_types[node_type_column].count):
  70. raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
  71. if adjacency_matrix_backward is not None and \
  72. adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
  73. self.data.node_types[node_type_row].count):
  74. raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')
  75. if node_type_row == node_type_column and \
  76. adjacency_matrix_backward is not None:
  77. raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
  78. if self.is_symmetric and adjacency_matrix_backward is not None:
  79. raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')
  80. if self.is_symmetric and node_type_row == node_type_column and \
  81. not torch.all(_equal(adjacency_matrix,
  82. adjacency_matrix.transpose(0, 1))):
  83. raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')
  84. if not self.is_symmetric and node_type_row != node_type_column and \
  85. adjacency_matrix_backward is None:
  86. raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None')
  87. if self.is_symmetric and node_type_row != node_type_column:
  88. adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
  89. self.relation_types.append(RelationType(name,
  90. node_type_row, node_type_column,
  91. adjacency_matrix, adjacency_matrix_backward))
  92. def node_name(self, index):
  93. return self.data.node_types[index].name
  94. def __repr__(self):
  95. s = 'Relation family %s' % self.name
  96. for r in self.relation_types:
  97. s += '\n - %s%s' % (r.name, ' (two-way)' \
  98. if (r.adjacency_matrix is not None \
  99. and r.adjacency_matrix_backward is not None) \
  100. or self.node_type_row == self.node_type_column \
  101. else '%s <- %s' % (self.node_name(self.node_type_row),
  102. self.node_name(self.node_type_column)))
  103. return s
  104. def repr_indented(self):
  105. s = ' - %s' % self.name
  106. for r in self.relation_types:
  107. s += '\n - %s%s' % (r.name, ' (two-way)' \
  108. if (r.adjacency_matrix is not None \
  109. and r.adjacency_matrix_backward is not None) \
  110. or self.node_type_row == self.node_type_column \
  111. else '%s <- %s' % (self.node_name(self.node_type_row),
  112. self.node_name(self.node_type_column)))
  113. return s
  114. class Data(object):
  115. node_types: List[NodeType]
  116. relation_families: List[RelationFamily]
  117. def __init__(self) -> None:
  118. self.node_types = []
  119. self.relation_families = []
  120. def add_node_type(self, name: str, count: int) -> None:
  121. name = str(name)
  122. count = int(count)
  123. if not name:
  124. raise ValueError('You must provide a non-empty node type name')
  125. if count <= 0:
  126. raise ValueError('You must provide a positive node count')
  127. self.node_types.append(NodeType(name, count))
  128. def add_relation_family(self, name: str, node_type_row: int,
  129. node_type_column: int, is_symmetric: bool,
  130. decoder_class: Type = DEDICOMDecoder):
  131. name = str(name)
  132. node_type_row = int(node_type_row)
  133. node_type_column = int(node_type_column)
  134. is_symmetric = bool(is_symmetric)
  135. if node_type_row < 0 or node_type_row >= len(self.node_types):
  136. raise ValueError('node_type_row outside of the valid range of node types')
  137. if node_type_column < 0 or node_type_column >= len(self.node_types):
  138. raise ValueError('node_type_column outside of the valid range of node types')
  139. fam = RelationFamily(self, name, node_type_row, node_type_column,
  140. is_symmetric, decoder_class)
  141. self.relation_families.append(fam)
  142. return fam
  143. def __repr__(self):
  144. n = len(self.node_types)
  145. if n == 0:
  146. return 'Empty Icosagon Data'
  147. s = ''
  148. s += 'Icosagon Data with:\n'
  149. s += '- ' + str(n) + ' node type(s):\n'
  150. for nt in self.node_types:
  151. s += ' - ' + nt.name + '\n'
  152. if len(self.relation_families) == 0:
  153. s += '- No relation families\n'
  154. return s.strip()
  155. s += '- %d relation families:\n' % len(self.relation_families)
  156. for fam in self.relation_families:
  157. s += fam.repr_indented() + '\n'
  158. return s.strip()