
fastconv.py 9.1KB

from typing import List, Union, Callable

from .data import Data, RelationFamily
from .trainprep import PreparedData, PreparedRelationFamily
import torch
from .weights import init_glorot
from .normalize import _sparse_coo_tensor
# NOTE: dropout() is called in FastGraphConv.forward() below; it is assumed to
# be provided by this package's dropout helper module.
from .dropout import dropout
import types
def _sparse_diag_cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('The list of matrices must be non-empty')

    if not all(m.is_sparse for m in matrices):
        raise ValueError('All matrices must be sparse')

    if not all(len(m.shape) == 2 for m in matrices):
        raise ValueError('All matrices must be 2D')

    indices = []
    values = []
    row_offset = 0
    col_offset = 0

    for m in matrices:
        ind = m._indices().clone()
        ind[0] += row_offset
        ind[1] += col_offset
        indices.append(ind)
        values.append(m._values())
        row_offset += m.shape[0]
        col_offset += m.shape[1]

    indices = torch.cat(indices, dim=1)
    values = torch.cat(values)

    return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
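A quick, self-contained sketch of what _sparse_diag_cat is meant to produce. The toy matrices are made up for illustration, and the direct call is left commented out because it needs this package's _sparse_coo_tensor helper; densified, the result should match torch.block_diag on the dense inputs.

import torch

a = torch.eye(2).to_sparse()                          # 2 x 2
b = torch.tensor([[0., 1., 0.],
                  [0., 0., 2.]]).to_sparse()          # 2 x 3

# _sparse_diag_cat([a, b]) should give a 4 x 5 sparse matrix with a in the
# top-left block and b in the bottom-right block.
expected = torch.block_diag(a.to_dense(), b.to_dense())
# result = _sparse_diag_cat([a, b])
# assert torch.equal(result.to_dense(), expected)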
def _cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('Empty list passed to _cat()')

    n = sum(a.is_sparse for a in matrices)
    if n != 0 and n != len(matrices):
        raise ValueError('All matrices must have the same layout (dense or sparse)')

    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
        raise ValueError('All matrices must have the same dimensions apart from dimension 0')

    if not matrices[0].is_sparse:
        return torch.cat(matrices)

    total_rows = sum(a.shape[0] for a in matrices)
    indices = []
    values = []
    row_offset = 0

    for a in matrices:
        ind = a._indices().clone()
        val = a._values()
        ind[0] += row_offset
        ind = ind.transpose(0, 1)
        indices.append(ind)
        values.append(val)
        row_offset += a.shape[0]

    indices = torch.cat(indices).transpose(0, 1)
    values = torch.cat(values)

    res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))

    return res
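Similarly, a small illustration of _cat on sparse inputs (toy matrices; the call itself is commented out since it needs the module context). It behaves like a row-wise torch.cat, but keeps the result sparse.

import torch

a = torch.tensor([[1., 0.],
                  [0., 2.]]).to_sparse()              # 2 x 2
b = torch.tensor([[3., 0.]]).to_sparse()              # 1 x 2

expected = torch.cat([ a.to_dense(), b.to_dense() ])  # 3 x 2, dense reference
# result = _cat([a, b])                               # 3 x 2, sparse
# assert torch.equal(result.to_dense(), expected)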
class FastGraphConv(torch.nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
        adjacency_matrices: List[torch.Tensor],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        in_channels = int(in_channels)
        out_channels = int(out_channels)

        if not isinstance(adjacency_matrices, list):
            raise TypeError('adjacency_matrices must be a list')

        if len(adjacency_matrices) == 0:
            raise ValueError('adjacency_matrices must not be empty')

        if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
            raise TypeError('adjacency_matrices elements must be of class torch.Tensor')

        if not all(m.is_sparse for m in adjacency_matrices):
            raise ValueError('adjacency_matrices elements must be sparse')

        keep_prob = float(keep_prob)

        if not isinstance(activation, types.FunctionType):
            raise TypeError('activation must be a function')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.keep_prob = keep_prob
        self.activation = activation

        self.num_row_nodes = len(adjacency_matrices[0])
        self.num_relation_types = len(adjacency_matrices)

        # Stack all relation-type adjacency matrices into one block-diagonal
        # sparse matrix, so a single sparse mm covers every relation type.
        self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)

        # One (in_channels, out_channels) weight block per relation type,
        # concatenated along dim 1.
        self.weights = torch.cat([
            init_glorot(in_channels, out_channels)
            for _ in range(self.num_relation_types)
        ], dim=1)

    def forward(self, x) -> torch.Tensor:
        if self.keep_prob < 1.:
            x = dropout(x, self.keep_prob)

        # Project the input features with all relation-type weights at once:
        # (num_nodes, num_relation_types * out_channels).
        res = torch.sparse.mm(x, self.weights) \
            if x.is_sparse \
            else torch.mm(x, self.weights)

        # Stack the per-relation blocks row-wise so they line up with the
        # block-diagonal adjacency matrix:
        # (num_relation_types * num_nodes, out_channels).
        res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
        res = torch.cat(res)

        res = torch.sparse.mm(self.adjacency_matrices, res) \
            if self.adjacency_matrices.is_sparse \
            else torch.mm(self.adjacency_matrices, res)

        # One output slice per relation type.
        res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)

        if self.activation is not None:
            res = self.activation(res)

        return res
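To make the forward pass above easier to follow, here is a standalone sketch of the same block-diagonal trick written in plain torch. The shapes are made up, identity matrices stand in for real relation-type adjacencies, and only the dense feature path is shown.

import torch

num_nodes, in_ch, out_ch, num_rel = 4, 8, 16, 2

x = torch.rand(num_nodes, in_ch)                        # node features
weights = torch.rand(in_ch, num_rel * out_ch)           # per-relation weights, concatenated along dim 1
adjacencies = [ torch.eye(num_nodes).to_sparse() for _ in range(num_rel) ]

# One dense mm projects the features for all relation types at once...
res = torch.mm(x, weights)                              # (num_nodes, num_rel * out_ch)

# ...the per-relation blocks are stacked row-wise to line up with the
# block-diagonal adjacency matrix...
res = torch.cat(torch.split(res, out_ch, dim=1))        # (num_rel * num_nodes, out_ch)

# ...and a single sparse mm applies all adjacencies in one go.
adj_diag = torch.block_diag(*[ a.to_dense() for a in adjacencies ]).to_sparse()
res = torch.sparse.mm(adj_diag, res)
res = res.view(num_rel, num_nodes, out_ch)              # res.shape == torch.Size([2, 4, 16])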
class FastConvLayer(torch.nn.Module):
    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.is_sparse = False
        self.next_layer_repr = None
        self.build()

    def build(self):
        # One list of convolutions per output node type.
        self.next_layer_repr = torch.nn.ModuleList([
            torch.nn.ModuleList()
            for _ in range(len(self.data.node_types))
        ])
        for fam in self.data.relation_families:
            self.build_family(fam)

    def build_family(self, fam) -> None:
        if fam.node_type_row == fam.node_type_column:
            self.build_fam_one_node_type(fam)
        else:
            self.build_fam_two_node_types(fam)

    def build_fam_one_node_type(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix
            for r in fam.relation_types
        ]
        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)
        conv.input_node_type = fam.node_type_column
        self.next_layer_repr[fam.node_type_row].append(conv)

    def build_fam_two_node_types(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix
            for r in fam.relation_types
            if r.adjacency_matrix is not None
        ]

        adjacency_matrices_backward = [
            r.adjacency_matrix_backward
            for r in fam.relation_types
            if r.adjacency_matrix_backward is not None
        ]

        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)

        conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
            self.output_dim[fam.node_type_column],
            adjacency_matrices_backward,
            self.keep_prob,
            self.rel_activation)

        conv.input_node_type = fam.node_type_column
        conv_backward.input_node_type = fam.node_type_row

        self.next_layer_repr[fam.node_type_row].append(conv)
        self.next_layer_repr[fam.node_type_column].append(conv_backward)

    def forward(self, prev_layer_repr):
        next_layer_repr = [ []
            for _ in range(len(self.data.node_types)) ]

        for output_node_type in range(len(self.data.node_types)):
            # Accumulate the contribution of every convolution feeding this
            # output node type, summed over relation types and L2-normalized.
            for conv in self.next_layer_repr[output_node_type]:
                rep = conv(prev_layer_repr[conv.input_node_type])
                rep = torch.sum(rep, dim=0)
                rep = torch.nn.functional.normalize(rep, p=2, dim=1)
                next_layer_repr[output_node_type].append(rep)

            if len(next_layer_repr[output_node_type]) == 0:
                next_layer_repr[output_node_type] = \
                    torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
            else:
                next_layer_repr[output_node_type] = \
                    sum(next_layer_repr[output_node_type])
                next_layer_repr[output_node_type] = \
                    self.layer_activation(next_layer_repr[output_node_type])

        return next_layer_repr
    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            output_dim = [output_dim] * len(data.node_types)

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')