IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

86 lines
3.2KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from ..dropout import dropout_sparse, \
  7. dropout
  8. from ..weights import init_glorot
  9. from typing import List, Callable
  10. class GraphConv(torch.nn.Module):
  11. """Convolution layer for sparse AND dense inputs."""
  12. def __init__(self, in_channels: int, out_channels: int,
  13. adjacency_matrix: torch.Tensor, **kwargs) -> None:
  14. super().__init__(**kwargs)
  15. self.in_channels = in_channels
  16. self.out_channels = out_channels
  17. self.weight = init_glorot(in_channels, out_channels)
  18. self.adjacency_matrix = adjacency_matrix
  19. def forward(self, x: torch.Tensor) -> torch.Tensor:
  20. x = torch.sparse.mm(x, self.weight) \
  21. if x.is_sparse \
  22. else torch.mm(x, self.weight)
  23. x = torch.sparse.mm(self.adjacency_matrix, x) \
  24. if self.adjacency_matrix.is_sparse \
  25. else torch.mm(self.adjacency_matrix, x)
  26. return x
  27. class DropoutGraphConvActivation(torch.nn.Module):
  28. def __init__(self, input_dim: int, output_dim: int,
  29. adjacency_matrix: torch.Tensor, keep_prob: float=1.,
  30. activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
  31. **kwargs) -> None:
  32. super().__init__(**kwargs)
  33. self.input_dim = input_dim
  34. self.output_dim = output_dim
  35. self.adjacency_matrix = adjacency_matrix
  36. self.keep_prob = keep_prob
  37. self.activation = activation
  38. self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
  39. def forward(self, x: torch.Tensor) -> torch.Tensor:
  40. x = dropout_sparse(x, self.keep_prob) \
  41. if x.is_sparse \
  42. else dropout(x, self.keep_prob)
  43. x = self.graph_conv(x)
  44. x = self.activation(x)
  45. return x
  46. class MultiDGCA(torch.nn.Module):
  47. def __init__(self, input_dim: List[int], output_dim: int,
  48. adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
  49. activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
  50. **kwargs) -> None:
  51. super().__init__(**kwargs)
  52. self.input_dim = input_dim
  53. self.output_dim = output_dim
  54. self.adjacency_matrices = adjacency_matrices
  55. self.keep_prob = keep_prob
  56. self.activation = activation
  57. self.dgca = None
  58. self.build()
  59. def build(self):
  60. if len(self.input_dim) != len(self.adjacency_matrices):
  61. raise ValueError('input_dim must have the same length as adjacency_matrices')
  62. self.dgca = []
  63. for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
  64. self.dgca.append(DropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
  65. def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
  66. if not isinstance(x, list):
  67. raise ValueError('x must be a list of tensors')
  68. out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
  69. for i, f in enumerate(self.dgca):
  70. out += f(x[i])
  71. out = torch.nn.functional.normalize(out, p=2, dim=1)
  72. return out