
203 lines / 6.9 KB

from typing import List, Union, Callable
from .data import Data
from .trainprep import PreparedData
import torch
from .weights import init_glorot
from .normalize import _sparse_coo_tensor
from .dropout import dropout  # dropout() is used in FastConvLayer.forward(); assumed to live in this package's dropout module


def _sparse_diag_cat(matrices: List[torch.Tensor]):
    # Concatenate 2D sparse matrices into one block-diagonal sparse matrix.
    if len(matrices) == 0:
        raise ValueError('The list of matrices must be non-empty')

    if not all(m.is_sparse for m in matrices):
        raise ValueError('All matrices must be sparse')

    if not all(len(m.shape) == 2 for m in matrices):
        raise ValueError('All matrices must be 2D')

    indices = []
    values = []
    row_offset = 0
    col_offset = 0

    for m in matrices:
        # Shift each matrix's indices by the running row/column offsets so
        # that it lands in its own diagonal block.
        ind = m._indices().clone()
        ind[0] += row_offset
        ind[1] += col_offset
        indices.append(ind)
        values.append(m._values())
        row_offset += m.shape[0]
        col_offset += m.shape[1]

    indices = torch.cat(indices, dim=1)
    values = torch.cat(values)

    return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
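
# Example (a minimal sketch): _sparse_diag_cat() places its inputs on the
# diagonal of one larger sparse matrix, e.g.:
#
#     a = torch.eye(2).to_sparse()
#     b = (2 * torch.eye(3)).to_sparse()
#     c = _sparse_diag_cat([a, b])
#     # c is a 5x5 sparse matrix; c.to_dense() has a in the top-left 2x2
#     # block, b in the bottom-right 3x3 block and zeros elsewhere.
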
def _cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('Empty list passed to _cat()')

    n = sum(a.is_sparse for a in matrices)
    if n != 0 and n != len(matrices):
        raise ValueError('All matrices must have the same layout (dense or sparse)')

    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
        raise ValueError('All matrices must have the same dimensions apart from dimension 0')

    if not matrices[0].is_sparse:
        return torch.cat(matrices)

    total_rows = sum(a.shape[0] for a in matrices)

    indices = []
    values = []
    row_offset = 0

    for a in matrices:
        ind = a._indices().clone()
        val = a._values()
        ind[0] += row_offset
        ind = ind.transpose(0, 1)
        indices.append(ind)
        values.append(val)
        row_offset += a.shape[0]

    indices = torch.cat(indices).transpose(0, 1)
    values = torch.cat(values)

    res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
    return res
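
# Example (a minimal sketch): _cat() stacks matrices along dimension 0 and
# keeps the sparse layout when all inputs are sparse, e.g.:
#
#     a = torch.eye(2).to_sparse()
#     b = torch.eye(2).to_sparse()
#     c = _cat([a, b])
#     # c is a sparse 4x2 matrix: the rows of b stacked below the rows of a.
#     # With dense inputs, _cat() simply falls back to torch.cat().
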

class FastGraphConv(torch.nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
        adjacency_matrix: List[torch.Tensor],
        **kwargs):

        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # One Glorot-initialised weight block per relation, concatenated column-wise.
        self.weight = torch.cat([
            init_glorot(in_channels, out_channels) \
            for _ in adjacency_matrix
        ], dim=1)
        # All relation adjacency matrices stacked row-wise.
        self.adjacency_matrix = _cat(adjacency_matrix)

    def forward(self, x):
        x = torch.sparse.mm(x, self.weight) \
            if x.is_sparse \
            else torch.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x) \
            if self.adjacency_matrix.is_sparse \
            else torch.mm(self.adjacency_matrix, x)
        return x
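
# Example (a minimal sketch; sizes are made up for illustration): with two
# 10x10 relation adjacency matrices, in_channels=5 and out_channels=3, the
# concatenated weight is 5x6 and the stacked adjacency matrix is 20x10, so a
# 10x5 input yields a 20x6 output.
#
#     adj = [ torch.eye(10).to_sparse(), torch.eye(10).to_sparse() ]
#     conv = FastGraphConv(5, 3, adj)
#     out = conv(torch.rand(10, 5))
#     # out.shape == torch.Size([20, 6])
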

class FastConvLayer(torch.nn.Module):
    adjacency_matrix: List[torch.Tensor]
    adjacency_matrix_backward: List[torch.Tensor]
    weight: List[torch.Tensor]
    weight_backward: List[torch.Tensor]

    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.adjacency_matrix = None
        self.adjacency_matrix_backward = None
        self.weight = None
        self.weight_backward = None

        self.build()

    def build(self):
        self.adjacency_matrix = []
        self.adjacency_matrix_backward = []
        self.weight = []
        self.weight_backward = []

        for fam in self.data.relation_families:
            # Collect the forward and backward adjacency matrices of the family.
            adj_mat = [ rel.adjacency_matrix \
                for rel in fam.relation_types \
                if rel.adjacency_matrix is not None ]
            adj_mat_back = [ rel.adjacency_matrix_backward \
                for rel in fam.relation_types \
                if rel.adjacency_matrix_backward is not None ]

            # One weight matrix per relation type, in each direction.
            weight = [ init_glorot(self.input_dim[fam.node_type_column],
                self.output_dim[fam.node_type_row]) \
                for _ in range(len(adj_mat)) ]
            weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
                self.output_dim[fam.node_type_row]) \
                for _ in range(len(adj_mat_back)) ]

            adj_mat = torch.cat(adj_mat) \
                if len(adj_mat) > 0 \
                else None
            adj_mat_back = torch.cat(adj_mat_back) \
                if len(adj_mat_back) > 0 \
                else None

            self.adjacency_matrix.append(adj_mat)
            self.adjacency_matrix_backward.append(adj_mat_back)
            self.weight.append(weight)
            self.weight_backward.append(weight_back)

    def forward(self, prev_layer_repr):
        # NOTE: this forward pass is unfinished; it computes per-family
        # intermediate results but does not yet aggregate or return them.
        for i, fam in enumerate(self.data.relation_families):
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]

            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]

            if adj_mat is not None:
                x = dropout(repr_column, keep_prob=self.keep_prob)
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                x = torch.sparse.mm(adj_mat, repr_row) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, repr_row)
                x = self.rel_activation(x)
                x = x.view(len(fam.relation_types), len(repr_row), -1)

            if adj_mat_back is not None:
                x = torch.sparse.mm(adj_mat_back, repr_row) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, repr_row)

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            output_dim = [output_dim] * len(data.node_types)

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')
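
# Example (a rough sketch only): constructing a FastConvLayer requires a Data
# (or PreparedData) object describing node types and relation families. The
# Data calls below (add_node_type / add_relation_family / add_relation_type)
# are assumptions about this package's data module and may differ in detail.
#
#     d = Data()
#     d.add_node_type('Dummy', 10)
#     fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
#     fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round().to_sparse())
#     layer = FastConvLayer([10], [32], d, keep_prob=0.9)
#     # layer.weight[0] then holds one 10x32 Glorot-initialised matrix per
#     # relation in the family; forward() above is still a work in progress.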