IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

245 Zeilen
8.9KB

  1. from typing import List, \
  2. Union, \
  3. Callable
  4. from .data import Data, \
  5. RelationFamily
  6. from .trainprep import PreparedData, \
  7. PreparedRelationFamily
  8. import torch
  9. from .weights import init_glorot
  10. from .normalize import _sparse_coo_tensor
  11. import types
  12. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  13. if len(matrices) == 0:
  14. raise ValueError('The list of matrices must be non-empty')
  15. if not all(m.is_sparse for m in matrices):
  16. raise ValueError('All matrices must be sparse')
  17. if not all(len(m.shape) == 2 for m in matrices):
  18. raise ValueError('All matrices must be 2D')
  19. indices = []
  20. values = []
  21. row_offset = 0
  22. col_offset = 0
  23. for m in matrices:
  24. ind = m._indices().clone()
  25. ind[0] += row_offset
  26. ind[1] += col_offset
  27. indices.append(ind)
  28. values.append(m._values())
  29. row_offset += m.shape[0]
  30. col_offset += m.shape[1]
  31. indices = torch.cat(indices, dim=1)
  32. values = torch.cat(values)
  33. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  34. def _cat(matrices: List[torch.Tensor]):
  35. if len(matrices) == 0:
  36. raise ValueError('Empty list passed to _cat()')
  37. n = sum(a.is_sparse for a in matrices)
  38. if n != 0 and n != len(matrices):
  39. raise ValueError('All matrices must have the same layout (dense or sparse)')
  40. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  41. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  42. if not matrices[0].is_sparse:
  43. return torch.cat(matrices)
  44. total_rows = sum(a.shape[0] for a in matrices)
  45. indices = []
  46. values = []
  47. row_offset = 0
  48. for a in matrices:
  49. ind = a._indices().clone()
  50. val = a._values()
  51. ind[0] += row_offset
  52. ind = ind.transpose(0, 1)
  53. indices.append(ind)
  54. values.append(val)
  55. row_offset += a.shape[0]
  56. indices = torch.cat(indices).transpose(0, 1)
  57. values = torch.cat(values)
  58. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  59. return res
  60. class FastGraphConv(torch.nn.Module):
  61. def __init__(self,
  62. in_channels: int,
  63. out_channels: int,
  64. adjacency_matrices: List[torch.Tensor],
  65. keep_prob: float = 1.,
  66. activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  67. **kwargs) -> None:
  68. super().__init__(**kwargs)
  69. in_channels = int(in_channels)
  70. out_channels = int(out_channels)
  71. if not isinstance(adjacency_matrices, list):
  72. raise TypeError('adjacency_matrices must be a list')
  73. if len(adjacency_matrices) == 0:
  74. raise ValueError('adjacency_matrices must not be empty')
  75. if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
  76. raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
  77. if not all(m.is_sparse for m in adjacency_matrices):
  78. raise ValueError('adjacency_matrices elements must be sparse')
  79. keep_prob = float(keep_prob)
  80. if not isinstance(activation, types.FunctionType):
  81. raise TypeError('activation must be a function')
  82. self.in_channels = in_channels
  83. self.out_channels = out_channels
  84. self.adjacency_matrices = adjacency_matrices
  85. self.keep_prob = keep_prob
  86. self.activation = activation
  87. self.num_row_nodes = len(adjacency_matrices[0])
  88. self.num_relation_types = len(adjacency_matrices)
  89. self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)
  90. self.weights = torch.cat([
  91. init_glorot(in_channels, out_channels) \
  92. for _ in range(self.num_relation_types)
  93. ], dim=1)
  94. def forward(self, x) -> torch.Tensor:
  95. if self.keep_prob < 1.:
  96. x = dropout(x, self.keep_prob)
  97. res = torch.sparse.mm(x, self.weights) \
  98. if x.is_sparse \
  99. else torch.mm(x, self.weights)
  100. res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
  101. res = torch.cat(res)
  102. res = torch.sparse.mm(self.adjacency_matrices, res) \
  103. if self.adjacency_matrices.is_sparse \
  104. else torch.mm(self.adjacency_matrices, res)
  105. res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)
  106. if self.activation is not None:
  107. res = self.activation(res)
  108. return res
  109. class FastConvLayer(torch.nn.Module):
  110. adjacency_matrix: List[torch.Tensor]
  111. adjacency_matrix_backward: List[torch.Tensor]
  112. weight: List[torch.Tensor]
  113. weight_backward: List[torch.Tensor]
  114. def __init__(self,
  115. input_dim: List[int],
  116. output_dim: List[int],
  117. data: Union[Data, PreparedData],
  118. keep_prob: float = 1.,
  119. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  120. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  121. **kwargs):
  122. super().__init__(**kwargs)
  123. self._check_params(input_dim, output_dim, data, keep_prob,
  124. rel_activation, layer_activation)
  125. self.input_dim = input_dim
  126. self.output_dim = output_dim
  127. self.data = data
  128. self.keep_prob = keep_prob
  129. self.rel_activation = rel_activation
  130. self.layer_activation = layer_activation
  131. self.adjacency_matrix = None
  132. self.adjacency_matrix_backward = None
  133. self.weight = None
  134. self.weight_backward = None
  135. self.build()
  136. def build(self):
  137. self.adjacency_matrix = []
  138. self.adjacency_matrix_backward = []
  139. self.weight = []
  140. self.weight_backward = []
  141. for fam in self.data.relation_families:
  142. adj_mat = [ rel.adjacency_matrix \
  143. for rel in fam.relation_types \
  144. if rel.adjacency_matrix is not None ]
  145. adj_mat_back = [ rel.adjacency_matrix_backward \
  146. for rel in fam.relation_types \
  147. if rel.adjacency_matrix_backward is not None ]
  148. weight = [ init_glorot(self.input_dim[fam.node_type_column],
  149. self.output_dim[fam.node_type_row]) \
  150. for _ in range(len(adj_mat)) ]
  151. weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
  152. self.output_dim[fam.node_type_row]) \
  153. for _ in range(len(adj_mat_back)) ]
  154. adj_mat = torch.cat(adj_mat) \
  155. if len(adj_mat) > 0 \
  156. else None
  157. adj_mat_back = torch.cat(adj_mat_back) \
  158. if len(adj_mat_back) > 0 \
  159. else None
  160. self.adjacency_matrix.append(adj_mat)
  161. self.adjacency_matrix_backward.append(adj_mat_back)
  162. self.weight.append(weight)
  163. self.weight_backward.append(weight_back)
  164. def forward(self, prev_layer_repr):
  165. for i, fam in enumerate(self.data.relation_families):
  166. repr_row = prev_layer_repr[fam.node_type_row]
  167. repr_column = prev_layer_repr[fam.node_type_column]
  168. adj_mat = self.adjacency_matrix[i]
  169. adj_mat_back = self.adjacency_matrix_backward[i]
  170. if adj_mat is not None:
  171. x = dropout(repr_column, keep_prob=self.keep_prob)
  172. x = torch.sparse.mm(x, self.weight[i]) \
  173. if x.is_sparse \
  174. else torch.mm(x, self.weight[i])
  175. x = torch.sparse.mm(adj_mat, x) \
  176. if adj_mat.is_sparse \
  177. else torch.mm(adj_mat, x)
  178. x = self.rel_activation(x)
  179. x = x.view(len(fam.relation_types), len(repr_row), -1)
  180. if adj_mat_back is not None:
  181. x = dropout(repr_column, keep_prob=self.keep_prob)
  182. x = torch.sparse.mm(x, self.weight_backward[i]) \
  183. if x.is_sparse \
  184. else torch.mm(x, self.weight_backward[i])
  185. x = torch.sparse.mm(adj_mat_back, x) \
  186. if adj_mat_back.is_sparse \
  187. else torch.mm(adj_mat_back, x)
  188. x = self.rel_activation(x)
  189. x = x.view(len(fam.relation_types), len(repr_row), -1)
  190. @staticmethod
  191. def _check_params(input_dim, output_dim, data, keep_prob,
  192. rel_activation, layer_activation):
  193. if not isinstance(input_dim, list):
  194. raise ValueError('input_dim must be a list')
  195. if not output_dim:
  196. raise ValueError('output_dim must be specified')
  197. if not isinstance(output_dim, list):
  198. output_dim = [output_dim] * len(data.node_types)
  199. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  200. raise ValueError('data must be of type Data or PreparedData')