IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or opening pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

79 lines
3.0KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from ..dropout import dropout_sparse
  7. from ..weights import init_glorot
  8. from typing import List, Callable
  9. class SparseGraphConv(torch.nn.Module):
  10. """Convolution layer for sparse inputs."""
  11. def __init__(self, in_channels: int, out_channels: int,
  12. adjacency_matrix: torch.Tensor, **kwargs) -> None:
  13. super().__init__(**kwargs)
  14. self.in_channels = in_channels
  15. self.out_channels = out_channels
  16. self.weight = init_glorot(in_channels, out_channels)
  17. self.adjacency_matrix = adjacency_matrix
  18. def forward(self, x: torch.Tensor) -> torch.Tensor:
  19. x = torch.sparse.mm(x, self.weight)
  20. x = torch.sparse.mm(self.adjacency_matrix, x)
  21. return x
  22. class SparseDropoutGraphConvActivation(torch.nn.Module):
  23. def __init__(self, input_dim: int, output_dim: int,
  24. adjacency_matrix: torch.Tensor, keep_prob: float=1.,
  25. activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
  26. **kwargs) -> None:
  27. super().__init__(**kwargs)
  28. self.input_dim = input_dim
  29. self.output_dim = output_dim
  30. self.adjacency_matrix = adjacency_matrix
  31. self.keep_prob = keep_prob
  32. self.activation = activation
  33. self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
  34. def forward(self, x: torch.Tensor) -> torch.Tensor:
  35. x = dropout_sparse(x, self.keep_prob)
  36. x = self.sparse_graph_conv(x)
  37. x = self.activation(x)
  38. return x
  39. class SparseMultiDGCA(torch.nn.Module):
  40. def __init__(self, input_dim: List[int], output_dim: int,
  41. adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
  42. activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
  43. **kwargs) -> None:
  44. super().__init__(**kwargs)
  45. self.input_dim = input_dim
  46. self.output_dim = output_dim
  47. self.adjacency_matrices = adjacency_matrices
  48. self.keep_prob = keep_prob
  49. self.activation = activation
  50. self.sparse_dgca = None
  51. self.build()
  52. def build(self):
  53. if len(self.input_dim) != len(self.adjacency_matrices):
  54. raise ValueError('input_dim must have the same length as adjacency_matrices')
  55. self.sparse_dgca = []
  56. for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
  57. self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
  58. def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
  59. if not isinstance(x, list):
  60. raise ValueError('x must be a list of tensors')
  61. out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
  62. for i, f in enumerate(self.sparse_dgca):
  63. out += f(x[i])
  64. out = torch.nn.functional.normalize(out, p=2, dim=1)
  65. return out