IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

258 lines
9.9KB

  1. from typing import List, \
  2. Union, \
  3. Callable
  4. from .data import Data, \
  5. RelationFamily
  6. from .trainprep import PreparedData, \
  7. PreparedRelationFamily
  8. import torch
  9. from .weights import init_glorot
  10. from .normalize import _sparse_coo_tensor
  11. import types
  12. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  13. if len(matrices) == 0:
  14. raise ValueError('The list of matrices must be non-empty')
  15. if not all(m.is_sparse for m in matrices):
  16. raise ValueError('All matrices must be sparse')
  17. if not all(len(m.shape) == 2 for m in matrices):
  18. raise ValueError('All matrices must be 2D')
  19. indices = []
  20. values = []
  21. row_offset = 0
  22. col_offset = 0
  23. for m in matrices:
  24. ind = m._indices().clone()
  25. ind[0] += row_offset
  26. ind[1] += col_offset
  27. indices.append(ind)
  28. values.append(m._values())
  29. row_offset += m.shape[0]
  30. col_offset += m.shape[1]
  31. indices = torch.cat(indices, dim=1)
  32. values = torch.cat(values)
  33. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  34. def _cat(matrices: List[torch.Tensor]):
  35. if len(matrices) == 0:
  36. raise ValueError('Empty list passed to _cat()')
  37. n = sum(a.is_sparse for a in matrices)
  38. if n != 0 and n != len(matrices):
  39. raise ValueError('All matrices must have the same layout (dense or sparse)')
  40. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  41. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  42. if not matrices[0].is_sparse:
  43. return torch.cat(matrices)
  44. total_rows = sum(a.shape[0] for a in matrices)
  45. indices = []
  46. values = []
  47. row_offset = 0
  48. for a in matrices:
  49. ind = a._indices().clone()
  50. val = a._values()
  51. ind[0] += row_offset
  52. ind = ind.transpose(0, 1)
  53. indices.append(ind)
  54. values.append(val)
  55. row_offset += a.shape[0]
  56. indices = torch.cat(indices).transpose(0, 1)
  57. values = torch.cat(values)
  58. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  59. return res
  60. class FastGraphConv(torch.nn.Module):
  61. def __init__(self,
  62. in_channels: List[int],
  63. out_channels: List[int],
  64. data: Union[Data, PreparedData],
  65. relation_family: Union[RelationFamily, PreparedRelationFamily],
  66. keep_prob: float = 1.,
  67. acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  68. **kwargs) -> None:
  69. in_channels = int(in_channels)
  70. out_channels = int(out_channels)
  71. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  72. raise TypeError('data must be an instance of Data or PreparedData')
  73. if not isinstance(relation_family, RelationFamily) and \
  74. not isinstance(relation_family, PreparedRelationFamily):
  75. raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily')
  76. keep_prob = float(keep_prob)
  77. if not isinstance(activation, types.FunctionType):
  78. raise TypeError('activation must be a function')
  79. n_nodes_row = data.node_types[relation_family.node_type_row].count
  80. n_nodes_column = data.node_types[relation_family.node_type_column].count
  81. self.in_channels = in_channels
  82. self.out_channels = out_channels
  83. self.data = data
  84. self.relation_family = relation_family
  85. self.keep_prob = keep_prob
  86. self.activation = activation
  87. self.weight = torch.cat([
  88. init_glorot(in_channels, out_channels) \
  89. for _ in range(len(relation_family.relation_types))
  90. ], dim=1)
  91. self.weight_backward = torch.cat([
  92. init_glorot(in_channels, out_channels) \
  93. for _ in range(len(relation_family.relation_types))
  94. ], dim=1)
  95. self.adjacency_matrix = _sparse_diag_cat([
  96. rel.adjacency_matrix \
  97. if rel.adjacency_matrix is not None \
  98. else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \
  99. for rel in relation_family.relation_types ])
  100. self.adjacency_matrix_backward = _sparse_diag_cat([
  101. rel.adjacency_matrix_backward \
  102. if rel.adjacency_matrix_backward is not None \
  103. else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \
  104. for rel in relation_family.relation_types ])
  105. def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
  106. repr_row = prev_layer_repr[self.relation_family.node_type_row]
  107. repr_column = prev_layer_repr[self.relation_family.node_type_column]
  108. new_repr_row = torch.sparse.mm(repr_column, self.weight) \
  109. if repr_column.is_sparse \
  110. else torch.mm(repr_column, self.weight)
  111. new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \
  112. if self.adjacency_matrix.is_sparse \
  113. else torch.mm(self.adjacency_matrix, new_repr_row)
  114. new_repr_row = new_repr_row.view(len(self.relation_family.relation_types),
  115. len(repr_row), self.out_channels)
  116. new_repr_column = torch.sparse.mm(repr_row, self.weight) \
  117. if repr_row.is_sparse \
  118. else torch.mm(repr_row, self.weight)
  119. new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \
  120. if self.adjacency_matrix_backward.is_sparse \
  121. else torch.mm(self.adjacency_matrix_backward, new_repr_column)
  122. new_repr_column = new_repr_column.view(len(self.relation_family.relation_types),
  123. len(repr_column), self.out_channels)
  124. return (new_repr_row, new_repr_column)
class FastConvLayer(torch.nn.Module):
    """Convolution layer that iterates over every relation family in the data."""

    # Per-family structures, one entry per relation family, populated by build().
    adjacency_matrix: List[torch.Tensor]
    # Backward-direction counterpart of adjacency_matrix (entries may be None).
    adjacency_matrix_backward: List[torch.Tensor]
    # NOTE(review): despite the List[torch.Tensor] annotations, build()
    # appends a *list* of weight tensors per family — confirm intended type.
    weight: List[torch.Tensor]
    weight_backward: List[torch.Tensor]
  130. def __init__(self,
  131. input_dim: List[int],
  132. output_dim: List[int],
  133. data: Union[Data, PreparedData],
  134. keep_prob: float = 1.,
  135. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  136. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  137. **kwargs):
  138. super().__init__(**kwargs)
  139. self._check_params(input_dim, output_dim, data, keep_prob,
  140. rel_activation, layer_activation)
  141. self.input_dim = input_dim
  142. self.output_dim = output_dim
  143. self.data = data
  144. self.keep_prob = keep_prob
  145. self.rel_activation = rel_activation
  146. self.layer_activation = layer_activation
  147. self.adjacency_matrix = None
  148. self.adjacency_matrix_backward = None
  149. self.weight = None
  150. self.weight_backward = None
  151. self.build()
  152. def build(self):
  153. self.adjacency_matrix = []
  154. self.adjacency_matrix_backward = []
  155. self.weight = []
  156. self.weight_backward = []
  157. for fam in self.data.relation_families:
  158. adj_mat = [ rel.adjacency_matrix \
  159. for rel in fam.relation_types \
  160. if rel.adjacency_matrix is not None ]
  161. adj_mat_back = [ rel.adjacency_matrix_backward \
  162. for rel in fam.relation_types \
  163. if rel.adjacency_matrix_backward is not None ]
  164. weight = [ init_glorot(self.input_dim[fam.node_type_column],
  165. self.output_dim[fam.node_type_row]) \
  166. for _ in range(len(adj_mat)) ]
  167. weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
  168. self.output_dim[fam.node_type_row]) \
  169. for _ in range(len(adj_mat_back)) ]
  170. adj_mat = torch.cat(adj_mat) \
  171. if len(adj_mat) > 0 \
  172. else None
  173. adj_mat_back = torch.cat(adj_mat_back) \
  174. if len(adj_mat_back) > 0 \
  175. else None
  176. self.adjacency_matrix.append(adj_mat)
  177. self.adjacency_matrix_backward.append(adj_mat_back)
  178. self.weight.append(weight)
  179. self.weight_back.append(weight_back)
    def forward(self, prev_layer_repr):
        # NOTE(review): this method looks unfinished and cannot work as
        # written: it has no return statement, 'dropout' is not defined or
        # imported anywhere in this file as shown, self.weight[i] is a
        # *list* of tensors (torch.mm/torch.sparse.mm would reject it), and
        # each branch overwrites 'x' instead of accumulating per-family
        # results — confirm intended behavior before use.
        for i, fam in enumerate(self.data.relation_families):
            # Representations of the family's row/column node types from
            # the previous layer.
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]

            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]

            if adj_mat is not None:
                x = dropout(repr_column, keep_prob=self.keep_prob)
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                # NOTE(review): this multiplies by repr_row, discarding the
                # 'x' just computed — presumably should use 'x'; confirm.
                x = torch.sparse.mm(adj_mat, repr_row) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, repr_row)
                x = self.rel_activation(x)
                # Split the stacked result back into per-relation slices.
                x = x.view(len(fam.relation_types), len(repr_row), -1)

            if adj_mat_back is not None:
                x = torch.sparse.mm(adj_mat_back, repr_row) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, repr_row)
  200. @staticmethod
  201. def _check_params(input_dim, output_dim, data, keep_prob,
  202. rel_activation, layer_activation):
  203. if not isinstance(input_dim, list):
  204. raise ValueError('input_dim must be a list')
  205. if not output_dim:
  206. raise ValueError('output_dim must be specified')
  207. if not isinstance(output_dim, list):
  208. output_dim = [output_dim] * len(data.node_types)
  209. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  210. raise ValueError('data must be of type Data or PreparedData')