IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

203 lines
6.9KB

  1. from typing import List, \
  2. Union, \
  3. Callable
  4. from .data import Data
  5. from .trainprep import PreparedData
  6. import torch
  7. from .weights import init_glorot
  8. from .normalize import _sparse_coo_tensor
  9. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  10. if len(matrices) == 0:
  11. raise ValueError('The list of matrices must be non-empty')
  12. if not all(m.is_sparse for m in matrices):
  13. raise ValueError('All matrices must be sparse')
  14. if not all(len(m.shape) == 2 for m in matrices):
  15. raise ValueError('All matrices must be 2D')
  16. indices = []
  17. values = []
  18. row_offset = 0
  19. col_offset = 0
  20. for m in matrices:
  21. ind = m._indices().clone()
  22. ind[0] += row_offset
  23. ind[1] += col_offset
  24. indices.append(ind)
  25. values.append(m._values())
  26. row_offset += m.shape[0]
  27. col_offset += m.shape[1]
  28. indices = torch.cat(indices, dim=1)
  29. values = torch.cat(values)
  30. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  31. def _cat(matrices: List[torch.Tensor]):
  32. if len(matrices) == 0:
  33. raise ValueError('Empty list passed to _cat()')
  34. n = sum(a.is_sparse for a in matrices)
  35. if n != 0 and n != len(matrices):
  36. raise ValueError('All matrices must have the same layout (dense or sparse)')
  37. if not all(a.shape[1:] == matrices[0].shape[1:]):
  38. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  39. if not matrices[0].is_sparse:
  40. return torch.cat(matrices)
  41. total_rows = sum(a.shape[0] for a in matrices)
  42. indices = []
  43. values = []
  44. row_offset = 0
  45. for a in matrices:
  46. ind = a._indices().clone()
  47. val = a._values()
  48. ind[0] += row_offset
  49. ind = ind.transpose(0, 1)
  50. indices.append(ind)
  51. values.append(val)
  52. row_offset += a.shape[0]
  53. indices = torch.cat(indices).transpose(0, 1)
  54. values = torch.cat(values)
  55. res = _sparse_coo_tensor(indices, values)
  56. return res
  57. class FastGraphConv(torch.nn.Module):
  58. def __init__(self,
  59. in_channels: int,
  60. out_channels: int,
  61. adjacency_matrix: List[torch.Tensor],
  62. **kwargs):
  63. self.in_channels = in_channels
  64. self.out_channels = out_channels
  65. self.weight = torch.cat([
  66. init_glorot(in_channels, out_channels) \
  67. for _ in adjacency_matrix
  68. ], dim=1)
  69. self.adjacency_matrix = _cat(adjacency_matrix)
  70. def forward(self, x):
  71. x = torch.sparse.mm(x, self.weight) \
  72. if x.is_sparse \
  73. else torch.mm(x, self.weight)
  74. x = torch.sparse.mm(self.adjacency_matrix, x) \
  75. if self.adjacency_matrix.is_sparse \
  76. else torch.mm(self.adjacency_matrix, x)
  77. return x
class FastConvLayer(torch.nn.Module):
    """Graph convolution layer over all relation families of ``data``.

    Within each family, the adjacency matrices (forward and backward
    separately) are stacked with ``torch.cat`` by ``build()``.
    """
    # The four lists below hold one entry per family in data.relation_families
    # (filled by build(); None before build() runs).
    # Stacked forward adjacency matrices, or None if the family has none.
    adjacency_matrix: List[Union[torch.Tensor, None]]
    # Stacked backward adjacency matrices, or None if the family has none.
    adjacency_matrix_backward: List[Union[torch.Tensor, None]]
    # Per family: a list with one init_glorot weight per stacked forward matrix.
    weight: List[List[torch.Tensor]]
    # Per family: a list with one init_glorot weight per stacked backward matrix.
    weight_backward: List[List[torch.Tensor]]
  83. def __init__(self,
  84. input_dim: List[int],
  85. output_dim: List[int],
  86. data: Union[Data, PreparedData],
  87. keep_prob: float = 1.,
  88. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  89. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  90. **kwargs):
  91. super().__init__(**kwargs)
  92. self._check_params(input_dim, output_dim, data, keep_prob,
  93. rel_activation, layer_activation)
  94. self.input_dim = input_dim
  95. self.output_dim = output_dim
  96. self.data = data
  97. self.keep_prob = keep_prob
  98. self.rel_activation = rel_activation
  99. self.layer_activation = layer_activation
  100. self.adjacency_matrix = None
  101. self.adjacency_matrix_backward = None
  102. self.weight = None
  103. self.weight_backward = None
  104. self.build()
  105. def build(self):
  106. self.adjacency_matrix = []
  107. self.adjacency_matrix_backward = []
  108. self.weight = []
  109. self.weight_backward = []
  110. for fam in self.data.relation_families:
  111. adj_mat = [ rel.adjacency_matrix \
  112. for rel in fam.relation_types \
  113. if rel.adjacency_matrix is not None ]
  114. adj_mat_back = [ rel.adjacency_matrix_backward \
  115. for rel in fam.relation_types \
  116. if rel.adjacency_matrix_backward is not None ]
  117. weight = [ init_glorot(self.input_dim[fam.node_type_column],
  118. self.output_dim[fam.node_type_row]) \
  119. for _ in range(len(adj_mat)) ]
  120. weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
  121. self.output_dim[fam.node_type_row]) \
  122. for _ in range(len(adj_mat_back)) ]
  123. adj_mat = torch.cat(adj_mat) \
  124. if len(adj_mat) > 0 \
  125. else None
  126. adj_mat_back = torch.cat(adj_mat_back) \
  127. if len(adj_mat_back) > 0 \
  128. else None
  129. self.adjacency_matrix.append(adj_mat)
  130. self.adjacency_matrix_backward.append(adj_mat_back)
  131. self.weight.append(weight)
  132. self.weight_back.append(weight_back)
    def forward(self, prev_layer_repr):
        """Propagate node representations through each relation family.

        NOTE(review): this method looks unfinished (see the notes below). As
        written it has no return statement, so it returns None, and it never
        applies ``self.layer_activation``.
        """
        # prev_layer_repr is indexed by fam.node_type_row / node_type_column;
        # presumably a list of per-node-type feature matrices -- TODO confirm.
        for i, fam in enumerate(self.data.relation_families):
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]
            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]
            if adj_mat is not None:
                # NOTE(review): `dropout` is not imported at the top of this
                # module; this raises NameError unless it is defined elsewhere
                # in the file.
                x = dropout(repr_column, keep_prob=self.keep_prob)
                # NOTE(review): build() stores a *list* of weights per family
                # in self.weight[i]; torch.mm cannot take a list, so the
                # weights presumably need concatenating first.
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                # NOTE(review): this discards the projected x computed above
                # and multiplies by repr_row instead; it looks like it should
                # use x.
                x = torch.sparse.mm(adj_mat, repr_row) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, repr_row)
                x = self.rel_activation(x)
                # Reshape to (relation types, row nodes, features).
                # NOTE(review): build() drops relation types whose
                # adjacency_matrix is None, so len(fam.relation_types) may not
                # match the number of stacked matrices.
                x = x.view(len(fam.relation_types), len(repr_row), -1)
            if adj_mat_back is not None:
                # NOTE(review): this overwrites x from the forward branch,
                # applies no weight, and the result is never used.
                x = torch.sparse.mm(adj_mat_back, repr_row) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, repr_row)
  153. @staticmethod
  154. def _check_params(input_dim, output_dim, data, keep_prob,
  155. rel_activation, layer_activation):
  156. if not isinstance(input_dim, list):
  157. raise ValueError('input_dim must be a list')
  158. if not output_dim:
  159. raise ValueError('output_dim must be specified')
  160. if not isinstance(output_dim, list):
  161. output_dim = [output_dim] * len(data.node_types)
  162. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  163. raise ValueError('data must be of type Data or PreparedData')