IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

fastconv.py 9.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. from typing import List, \
  2. Union, \
  3. Callable
  4. from .data import Data, \
  5. RelationFamily
  6. from .trainprep import PreparedData, \
  7. PreparedRelationFamily
  8. import torch
  9. from .weights import init_glorot
  10. from .normalize import _sparse_coo_tensor
  11. import types
  12. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  13. if len(matrices) == 0:
  14. raise ValueError('The list of matrices must be non-empty')
  15. if not all(m.is_sparse for m in matrices):
  16. raise ValueError('All matrices must be sparse')
  17. if not all(len(m.shape) == 2 for m in matrices):
  18. raise ValueError('All matrices must be 2D')
  19. indices = []
  20. values = []
  21. row_offset = 0
  22. col_offset = 0
  23. for m in matrices:
  24. ind = m._indices().clone()
  25. ind[0] += row_offset
  26. ind[1] += col_offset
  27. indices.append(ind)
  28. values.append(m._values())
  29. row_offset += m.shape[0]
  30. col_offset += m.shape[1]
  31. indices = torch.cat(indices, dim=1)
  32. values = torch.cat(values)
  33. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  34. def _cat(matrices: List[torch.Tensor]):
  35. if len(matrices) == 0:
  36. raise ValueError('Empty list passed to _cat()')
  37. n = sum(a.is_sparse for a in matrices)
  38. if n != 0 and n != len(matrices):
  39. raise ValueError('All matrices must have the same layout (dense or sparse)')
  40. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  41. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  42. if not matrices[0].is_sparse:
  43. return torch.cat(matrices)
  44. total_rows = sum(a.shape[0] for a in matrices)
  45. indices = []
  46. values = []
  47. row_offset = 0
  48. for a in matrices:
  49. ind = a._indices().clone()
  50. val = a._values()
  51. ind[0] += row_offset
  52. ind = ind.transpose(0, 1)
  53. indices.append(ind)
  54. values.append(val)
  55. row_offset += a.shape[0]
  56. indices = torch.cat(indices).transpose(0, 1)
  57. values = torch.cat(values)
  58. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  59. return res
  60. class FastGraphConv(torch.nn.Module):
  61. def __init__(self,
  62. in_channels: int,
  63. out_channels: int,
  64. adjacency_matrices: List[torch.Tensor],
  65. keep_prob: float = 1.,
  66. activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  67. **kwargs) -> None:
  68. super().__init__(**kwargs)
  69. in_channels = int(in_channels)
  70. out_channels = int(out_channels)
  71. if not isinstance(adjacency_matrices, list):
  72. raise TypeError('adjacency_matrices must be a list')
  73. if len(adjacency_matrices) == 0:
  74. raise ValueError('adjacency_matrices must not be empty')
  75. if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
  76. raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
  77. if not all(m.is_sparse for m in adjacency_matrices):
  78. raise ValueError('adjacency_matrices elements must be sparse')
  79. keep_prob = float(keep_prob)
  80. if not isinstance(activation, types.FunctionType):
  81. raise TypeError('activation must be a function')
  82. self.in_channels = in_channels
  83. self.out_channels = out_channels
  84. self.adjacency_matrices = adjacency_matrices
  85. self.keep_prob = keep_prob
  86. self.activation = activation
  87. self.num_row_nodes = len(adjacency_matrices[0])
  88. self.num_relation_types = len(adjacency_matrices)
  89. self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)
  90. self.weights = torch.cat([
  91. init_glorot(in_channels, out_channels) \
  92. for _ in range(self.num_relation_types)
  93. ], dim=1)
  94. def forward(self, x) -> torch.Tensor:
  95. if self.keep_prob < 1.:
  96. x = dropout(x, self.keep_prob)
  97. res = torch.sparse.mm(x, self.weights) \
  98. if x.is_sparse \
  99. else torch.mm(x, self.weights)
  100. res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
  101. res = torch.cat(res)
  102. res = torch.sparse.mm(self.adjacency_matrices, res) \
  103. if self.adjacency_matrices.is_sparse \
  104. else torch.mm(self.adjacency_matrices, res)
  105. res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)
  106. if self.activation is not None:
  107. res = self.activation(res)
  108. return res
  109. class FastConvLayer(torch.nn.Module):
  110. def __init__(self,
  111. input_dim: List[int],
  112. output_dim: List[int],
  113. data: Union[Data, PreparedData],
  114. keep_prob: float = 1.,
  115. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  116. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  117. **kwargs):
  118. super().__init__(**kwargs)
  119. self._check_params(input_dim, output_dim, data, keep_prob,
  120. rel_activation, layer_activation)
  121. self.input_dim = input_dim
  122. self.output_dim = output_dim
  123. self.data = data
  124. self.keep_prob = keep_prob
  125. self.rel_activation = rel_activation
  126. self.layer_activation = layer_activation
  127. self.is_sparse = False
  128. self.next_layer_repr = None
  129. self.build()
  130. def build(self):
  131. self.next_layer_repr = torch.nn.ModuleList([
  132. torch.nn.ModuleList() \
  133. for _ in range(len(self.data.node_types))
  134. ])
  135. for fam in self.data.relation_families:
  136. self.build_family(fam)
  137. def build_family(self, fam) -> None:
  138. if fam.node_type_row == fam.node_type_column:
  139. self.build_fam_one_node_type(fam)
  140. else:
  141. self.build_fam_two_node_types(fam)
  142. def build_fam_one_node_type(self, fam) -> None:
  143. adjacency_matrices = [
  144. r.adjacency_matrix \
  145. for r in fam.relation_types
  146. ]
  147. conv = FastGraphConv(self.input_dim[fam.node_type_column],
  148. self.output_dim[fam.node_type_row],
  149. adjacency_matrices,
  150. self.keep_prob,
  151. self.rel_activation)
  152. conv.input_node_type = fam.node_type_column
  153. self.next_layer_repr[fam.node_type_row].append(conv)
  154. def build_fam_two_node_types(self, fam) -> None:
  155. adjacency_matrices = [
  156. r.adjacency_matrix \
  157. for r in fam.relation_types \
  158. if r.adjacency_matrix is not None
  159. ]
  160. adjacency_matrices_backward = [
  161. r.adjacency_matrix_backward \
  162. for r in fam.relation_types \
  163. if r.adjacency_matrix_backward is not None
  164. ]
  165. conv = FastGraphConv(self.input_dim[fam.node_type_column],
  166. self.output_dim[fam.node_type_row],
  167. adjacency_matrices,
  168. self.keep_prob,
  169. self.rel_activation)
  170. conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
  171. self.output_dim[fam.node_type_column],
  172. adjacency_matrices_backward,
  173. self.keep_prob,
  174. self.rel_activation)
  175. conv.input_node_type = fam.node_type_column
  176. conv_backward.input_node_type = fam.node_type_row
  177. self.next_layer_repr[fam.node_type_row].append(conv)
  178. self.next_layer_repr[fam.node_type_column].append(conv_backward)
  179. def forward(self, prev_layer_repr):
  180. next_layer_repr = [ [] \
  181. for _ in range(len(self.data.node_types)) ]
  182. for output_node_type in range(len(self.data.node_types)):
  183. for conv in self.next_layer_repr[output_node_type]:
  184. rep = conv(prev_layer_repr[conv.input_node_type])
  185. rep = torch.sum(rep, dim=0)
  186. rep = torch.nn.functional.normalize(rep, p=2, dim=1)
  187. next_layer_repr[output_node_type].append(rep)
  188. if len(next_layer_repr[output_node_type]) == 0:
  189. next_layer_repr[output_node_type] = \
  190. torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
  191. else:
  192. next_layer_repr[output_node_type] = \
  193. sum(next_layer_repr[output_node_type])
  194. next_layer_repr[output_node_type] = \
  195. self.layer_activation(next_layer_repr[output_node_type])
  196. return next_layer_repr
  197. @staticmethod
  198. def _check_params(input_dim, output_dim, data, keep_prob,
  199. rel_activation, layer_activation):
  200. if not isinstance(input_dim, list):
  201. raise ValueError('input_dim must be a list')
  202. if not output_dim:
  203. raise ValueError('output_dim must be specified')
  204. if not isinstance(output_dim, list):
  205. output_dim = [output_dim] * len(data.node_types)
  206. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  207. raise ValueError('data must be of type Data or PreparedData')