If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

245 lines
8.9KB

  1. from typing import List, \
  2. Union, \
  3. Callable
  4. from .data import Data, \
  5. RelationFamily
  6. from .trainprep import PreparedData, \
  7. PreparedRelationFamily
  8. import torch
  9. from .weights import init_glorot
  10. from .normalize import _sparse_coo_tensor
  11. import types
  12. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  13. if len(matrices) == 0:
  14. raise ValueError('The list of matrices must be non-empty')
  15. if not all(m.is_sparse for m in matrices):
  16. raise ValueError('All matrices must be sparse')
  17. if not all(len(m.shape) == 2 for m in matrices):
  18. raise ValueError('All matrices must be 2D')
  19. indices = []
  20. values = []
  21. row_offset = 0
  22. col_offset = 0
  23. for m in matrices:
  24. ind = m._indices().clone()
  25. ind[0] += row_offset
  26. ind[1] += col_offset
  27. indices.append(ind)
  28. values.append(m._values())
  29. row_offset += m.shape[0]
  30. col_offset += m.shape[1]
  31. indices = torch.cat(indices, dim=1)
  32. values = torch.cat(values)
  33. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  34. def _cat(matrices: List[torch.Tensor]):
  35. if len(matrices) == 0:
  36. raise ValueError('Empty list passed to _cat()')
  37. n = sum(a.is_sparse for a in matrices)
  38. if n != 0 and n != len(matrices):
  39. raise ValueError('All matrices must have the same layout (dense or sparse)')
  40. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  41. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  42. if not matrices[0].is_sparse:
  43. return torch.cat(matrices)
  44. total_rows = sum(a.shape[0] for a in matrices)
  45. indices = []
  46. values = []
  47. row_offset = 0
  48. for a in matrices:
  49. ind = a._indices().clone()
  50. val = a._values()
  51. ind[0] += row_offset
  52. ind = ind.transpose(0, 1)
  53. indices.append(ind)
  54. values.append(val)
  55. row_offset += a.shape[0]
  56. indices = torch.cat(indices).transpose(0, 1)
  57. values = torch.cat(values)
  58. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  59. return res
  60. class FastGraphConv(torch.nn.Module):
  61. def __init__(self,
  62. in_channels: int,
  63. out_channels: int,
  64. adjacency_matrices: List[torch.Tensor],
  65. keep_prob: float = 1.,
  66. activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  67. **kwargs) -> None:
  68. super().__init__(**kwargs)
  69. in_channels = int(in_channels)
  70. out_channels = int(out_channels)
  71. if not isinstance(adjacency_matrices, list):
  72. raise TypeError('adjacency_matrices must be a list')
  73. if len(adjacency_matrices) == 0:
  74. raise ValueError('adjacency_matrices must not be empty')
  75. if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
  76. raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
  77. if not all(m.is_sparse for m in adjacency_matrices):
  78. raise ValueError('adjacency_matrices elements must be sparse')
  79. keep_prob = float(keep_prob)
  80. if not isinstance(activation, types.FunctionType):
  81. raise TypeError('activation must be a function')
  82. self.in_channels = in_channels
  83. self.out_channels = out_channels
  84. self.adjacency_matrices = adjacency_matrices
  85. self.keep_prob = keep_prob
  86. self.activation = activation
  87. self.num_row_nodes = len(adjacency_matrices[0])
  88. self.num_relation_types = len(adjacency_matrices)
  89. self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)
  90. self.weights = torch.cat([
  91. init_glorot(in_channels, out_channels) \
  92. for _ in range(self.num_relation_types)
  93. ], dim=1)
  94. def forward(self, x) -> torch.Tensor:
  95. if self.keep_prob < 1.:
  96. x = dropout(x, self.keep_prob)
  97. res = torch.sparse.mm(x, self.weights) \
  98. if x.is_sparse \
  99. else torch.mm(x, self.weights)
  100. res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
  101. res = torch.cat(res)
  102. res = torch.sparse.mm(self.adjacency_matrices, res) \
  103. if self.adjacency_matrices.is_sparse \
  104. else torch.mm(self.adjacency_matrices, res)
  105. res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)
  106. if self.activation is not None:
  107. res = self.activation(res)
  108. return res
  109. class FastConvLayer(torch.nn.Module):
  110. adjacency_matrix: List[torch.Tensor]
  111. adjacency_matrix_backward: List[torch.Tensor]
  112. weight: List[torch.Tensor]
  113. weight_backward: List[torch.Tensor]
  114. def __init__(self,
  115. input_dim: List[int],
  116. output_dim: List[int],
  117. data: Union[Data, PreparedData],
  118. keep_prob: float = 1.,
  119. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  120. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  121. **kwargs):
  122. super().__init__(**kwargs)
  123. self._check_params(input_dim, output_dim, data, keep_prob,
  124. rel_activation, layer_activation)
  125. self.input_dim = input_dim
  126. self.output_dim = output_dim
  127. self.data = data
  128. self.keep_prob = keep_prob
  129. self.rel_activation = rel_activation
  130. self.layer_activation = layer_activation
  131. self.adjacency_matrix = None
  132. self.adjacency_matrix_backward = None
  133. self.weight = None
  134. self.weight_backward = None
  135. self.build()
  136. def build(self):
  137. self.adjacency_matrix = []
  138. self.adjacency_matrix_backward = []
  139. self.weight = []
  140. self.weight_backward = []
  141. for fam in self.data.relation_families:
  142. adj_mat = [ rel.adjacency_matrix \
  143. for rel in fam.relation_types \
  144. if rel.adjacency_matrix is not None ]
  145. adj_mat_back = [ rel.adjacency_matrix_backward \
  146. for rel in fam.relation_types \
  147. if rel.adjacency_matrix_backward is not None ]
  148. weight = [ init_glorot(self.input_dim[fam.node_type_column],
  149. self.output_dim[fam.node_type_row]) \
  150. for _ in range(len(adj_mat)) ]
  151. weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
  152. self.output_dim[fam.node_type_row]) \
  153. for _ in range(len(adj_mat_back)) ]
  154. adj_mat = torch.cat(adj_mat) \
  155. if len(adj_mat) > 0 \
  156. else None
  157. adj_mat_back = torch.cat(adj_mat_back) \
  158. if len(adj_mat_back) > 0 \
  159. else None
  160. self.adjacency_matrix.append(adj_mat)
  161. self.adjacency_matrix_backward.append(adj_mat_back)
  162. self.weight.append(weight)
  163. self.weight_backward.append(weight_back)
  164. def forward(self, prev_layer_repr):
  165. for i, fam in enumerate(self.data.relation_families):
  166. repr_row = prev_layer_repr[fam.node_type_row]
  167. repr_column = prev_layer_repr[fam.node_type_column]
  168. adj_mat = self.adjacency_matrix[i]
  169. adj_mat_back = self.adjacency_matrix_backward[i]
  170. if adj_mat is not None:
  171. x = dropout(repr_column, keep_prob=self.keep_prob)
  172. x = torch.sparse.mm(x, self.weight[i]) \
  173. if x.is_sparse \
  174. else torch.mm(x, self.weight[i])
  175. x = torch.sparse.mm(adj_mat, x) \
  176. if adj_mat.is_sparse \
  177. else torch.mm(adj_mat, x)
  178. x = self.rel_activation(x)
  179. x = x.view(len(fam.relation_types), len(repr_row), -1)
  180. if adj_mat_back is not None:
  181. x = dropout(repr_column, keep_prob=self.keep_prob)
  182. x = torch.sparse.mm(x, self.weight_backward[i]) \
  183. if x.is_sparse \
  184. else torch.mm(x, self.weight_backward[i])
  185. x = torch.sparse.mm(adj_mat_back, x) \
  186. if adj_mat_back.is_sparse \
  187. else torch.mm(adj_mat_back, x)
  188. x = self.rel_activation(x)
  189. x = x.view(len(fam.relation_types), len(repr_row), -1)
  190. @staticmethod
  191. def _check_params(input_dim, output_dim, data, keep_prob,
  192. rel_activation, layer_activation):
  193. if not isinstance(input_dim, list):
  194. raise ValueError('input_dim must be a list')
  195. if not output_dim:
  196. raise ValueError('output_dim must be specified')
  197. if not isinstance(output_dim, list):
  198. output_dim = [output_dim] * len(data.node_types)
  199. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  200. raise ValueError('data must be of type Data or PreparedData')