IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

258 wiersze
9.9KB

  1. from typing import List, \
  2. Union, \
  3. Callable
  4. from .data import Data, \
  5. RelationFamily
  6. from .trainprep import PreparedData, \
  7. PreparedRelationFamily
  8. import torch
  9. from .weights import init_glorot
  10. from .normalize import _sparse_coo_tensor
  11. import types
  12. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  13. if len(matrices) == 0:
  14. raise ValueError('The list of matrices must be non-empty')
  15. if not all(m.is_sparse for m in matrices):
  16. raise ValueError('All matrices must be sparse')
  17. if not all(len(m.shape) == 2 for m in matrices):
  18. raise ValueError('All matrices must be 2D')
  19. indices = []
  20. values = []
  21. row_offset = 0
  22. col_offset = 0
  23. for m in matrices:
  24. ind = m._indices().clone()
  25. ind[0] += row_offset
  26. ind[1] += col_offset
  27. indices.append(ind)
  28. values.append(m._values())
  29. row_offset += m.shape[0]
  30. col_offset += m.shape[1]
  31. indices = torch.cat(indices, dim=1)
  32. values = torch.cat(values)
  33. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  34. def _cat(matrices: List[torch.Tensor]):
  35. if len(matrices) == 0:
  36. raise ValueError('Empty list passed to _cat()')
  37. n = sum(a.is_sparse for a in matrices)
  38. if n != 0 and n != len(matrices):
  39. raise ValueError('All matrices must have the same layout (dense or sparse)')
  40. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  41. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  42. if not matrices[0].is_sparse:
  43. return torch.cat(matrices)
  44. total_rows = sum(a.shape[0] for a in matrices)
  45. indices = []
  46. values = []
  47. row_offset = 0
  48. for a in matrices:
  49. ind = a._indices().clone()
  50. val = a._values()
  51. ind[0] += row_offset
  52. ind = ind.transpose(0, 1)
  53. indices.append(ind)
  54. values.append(val)
  55. row_offset += a.shape[0]
  56. indices = torch.cat(indices).transpose(0, 1)
  57. values = torch.cat(values)
  58. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  59. return res
  60. class FastGraphConv(torch.nn.Module):
  61. def __init__(self,
  62. in_channels: List[int],
  63. out_channels: List[int],
  64. data: Union[Data, PreparedData],
  65. relation_family: Union[RelationFamily, PreparedRelationFamily]
  66. keep_prob: float = 1.,
  67. acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  68. **kwargs) -> None:
  69. in_channels = int(in_channels)
  70. out_channels = int(out_channels)
  71. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  72. raise TypeError('data must be an instance of Data or PreparedData')
  73. if not isinstance(relation_family, RelationFamily) and \
  74. not isinstance(relation_family, PreparedRelationFamily):
  75. raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily')
  76. keep_prob = float(keep_prob)
  77. if not isinstance(activation, types.FunctionType):
  78. raise TypeError('activation must be a function')
  79. n_nodes_row = data.node_types[relation_family.node_type_row].count
  80. n_nodes_column = data.node_types[relation_family.node_type_column].count
  81. self.in_channels = in_channels
  82. self.out_channels = out_channels
  83. self.data = data
  84. self.relation_family = relation_family
  85. self.keep_prob = keep_prob
  86. self.activation = activation
  87. self.weight = torch.cat([
  88. init_glorot(in_channels, out_channels) \
  89. for _ in range(len(relation_family.relation_types))
  90. ], dim=1)
  91. self.weight_backward = torch.cat([
  92. init_glorot(in_channels, out_channels) \
  93. for _ in range(len(relation_family.relation_types))
  94. ], dim=1)
  95. self.adjacency_matrix = _sparse_diag_cat([
  96. rel.adjacency_matrix \
  97. if rel.adjacency_matrix is not None \
  98. else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \
  99. for rel in relation_family.relation_types ])
  100. self.adjacency_matrix_backward = _sparse_diag_cat([
  101. rel.adjacency_matrix_backward \
  102. if rel.adjacency_matrix_backward is not None \
  103. else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \
  104. for rel in relation_family.relation_types ])
  105. def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
  106. repr_row = prev_layer_repr[self.relation_family.node_type_row]
  107. repr_column = prev_layer_repr[self.relation_family.node_type_column]
  108. new_repr_row = torch.sparse.mm(repr_column, self.weight) \
  109. if repr_column.is_sparse \
  110. else torch.mm(repr_column, self.weight)
  111. new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \
  112. if self.adjacency_matrix.is_sparse \
  113. else torch.mm(self.adjacency_matrix, new_repr_row)
  114. new_repr_row = new_repr_row.view(len(self.relation_family.relation_types),
  115. len(repr_row), self.out_channels)
  116. new_repr_column = torch.sparse.mm(repr_row, self.weight) \
  117. if repr_row.is_sparse \
  118. else torch.mm(repr_row, self.weight)
  119. new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \
  120. if self.adjacency_matrix_backward.is_sparse \
  121. else torch.mm(self.adjacency_matrix_backward, new_repr_column)
  122. new_repr_column = new_repr_column.view(len(self.relation_family.relation_types),
  123. len(repr_column), self.out_channels)
  124. return (new_repr_row, new_repr_column)
  125. class FastConvLayer(torch.nn.Module):
  126. adjacency_matrix: List[torch.Tensor]
  127. adjacency_matrix_backward: List[torch.Tensor]
  128. weight: List[torch.Tensor]
  129. weight_backward: List[torch.Tensor]
  130. def __init__(self,
  131. input_dim: List[int],
  132. output_dim: List[int],
  133. data: Union[Data, PreparedData],
  134. keep_prob: float = 1.,
  135. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  136. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  137. **kwargs):
  138. super().__init__(**kwargs)
  139. self._check_params(input_dim, output_dim, data, keep_prob,
  140. rel_activation, layer_activation)
  141. self.input_dim = input_dim
  142. self.output_dim = output_dim
  143. self.data = data
  144. self.keep_prob = keep_prob
  145. self.rel_activation = rel_activation
  146. self.layer_activation = layer_activation
  147. self.adjacency_matrix = None
  148. self.adjacency_matrix_backward = None
  149. self.weight = None
  150. self.weight_backward = None
  151. self.build()
  152. def build(self):
  153. self.adjacency_matrix = []
  154. self.adjacency_matrix_backward = []
  155. self.weight = []
  156. self.weight_backward = []
  157. for fam in self.data.relation_families:
  158. adj_mat = [ rel.adjacency_matrix \
  159. for rel in fam.relation_types \
  160. if rel.adjacency_matrix is not None ]
  161. adj_mat_back = [ rel.adjacency_matrix_backward \
  162. for rel in fam.relation_types \
  163. if rel.adjacency_matrix_backward is not None ]
  164. weight = [ init_glorot(self.input_dim[fam.node_type_column],
  165. self.output_dim[fam.node_type_row]) \
  166. for _ in range(len(adj_mat)) ]
  167. weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
  168. self.output_dim[fam.node_type_row]) \
  169. for _ in range(len(adj_mat_back)) ]
  170. adj_mat = torch.cat(adj_mat) \
  171. if len(adj_mat) > 0 \
  172. else None
  173. adj_mat_back = torch.cat(adj_mat_back) \
  174. if len(adj_mat_back) > 0 \
  175. else None
  176. self.adjacency_matrix.append(adj_mat)
  177. self.adjacency_matrix_backward.append(adj_mat_back)
  178. self.weight.append(weight)
  179. self.weight_back.append(weight_back)
    def forward(self, prev_layer_repr):
        # NOTE(review): this method looks unfinished — it computes
        # intermediate results but has no return statement anywhere, so it
        # always returns None. See the inline notes below.
        for i, fam in enumerate(self.data.relation_families):
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]

            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]

            if adj_mat is not None:
                # NOTE(review): `dropout` is not defined or imported anywhere
                # in this file — this line raises NameError when reached.
                x = dropout(repr_column, keep_prob=self.keep_prob)
                # NOTE(review): build() stores self.weight[i] as a *list* of
                # tensors, which torch.mm would reject — confirm intended
                # structure.
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                # NOTE(review): the weighted x computed above is discarded
                # here; repr_row is multiplied instead — presumably x was
                # intended.
                x = torch.sparse.mm(adj_mat, repr_row) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, repr_row)
                x = self.rel_activation(x)
                x = x.view(len(fam.relation_types), len(repr_row), -1)

            if adj_mat_back is not None:
                # NOTE(review): overwrites x without using it; no dropout,
                # weighting, activation or view on this branch — looks like
                # an incomplete mirror of the forward branch above.
                x = torch.sparse.mm(adj_mat_back, repr_row) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, repr_row)
  200. @staticmethod
  201. def _check_params(input_dim, output_dim, data, keep_prob,
  202. rel_activation, layer_activation):
  203. if not isinstance(input_dim, list):
  204. raise ValueError('input_dim must be a list')
  205. if not output_dim:
  206. raise ValueError('output_dim must be specified')
  207. if not isinstance(output_dim, list):
  208. output_dim = [output_dim] * len(data.node_types)
  209. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  210. raise ValueError('data must be of type Data or PreparedData')