IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

59 lines
2.0KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from .dropout import dropout
  7. from .weights import init_glorot
  8. from typing import List, Callable
  9. import pdb
  10. class GraphConv(torch.nn.Module):
  11. def __init__(self, in_channels: int, out_channels: int,
  12. adjacency_matrix: torch.Tensor, **kwargs) -> None:
  13. super().__init__(**kwargs)
  14. self.in_channels = in_channels
  15. self.out_channels = out_channels
  16. self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
  17. self.adjacency_matrix = adjacency_matrix
  18. def forward(self, x: torch.Tensor) -> torch.Tensor:
  19. x = torch.sparse.mm(x, self.weight) \
  20. if x.is_sparse \
  21. else torch.mm(x, self.weight)
  22. x = torch.sparse.mm(self.adjacency_matrix, x) \
  23. if self.adjacency_matrix.is_sparse \
  24. else torch.mm(self.adjacency_matrix, x)
  25. return x
  26. class DropoutGraphConvActivation(torch.nn.Module):
  27. def __init__(self, input_dim: int, output_dim: int,
  28. adjacency_matrix: torch.Tensor, keep_prob: float=1.,
  29. activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
  30. **kwargs) -> None:
  31. super().__init__(**kwargs)
  32. self.input_dim = input_dim
  33. self.output_dim = output_dim
  34. self.adjacency_matrix = adjacency_matrix
  35. self.keep_prob = keep_prob
  36. self.activation = activation
  37. self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
  38. def forward(self, x: torch.Tensor) -> torch.Tensor:
  39. # pdb.set_trace()
  40. x = dropout(x, self.keep_prob)
  41. x = self.graph_conv(x)
  42. x = self.activation(x)
  43. return x
  44. def clone(self, adjacency_matrix) -> 'DropoutGraphConvActivation':
  45. res = DropoutGraphConvActivation(self.input_dim,
  46. self.output_dim, adjacency_matrix, self.keep_prob,
  47. self.activation)
  48. res.graph_conv.weight = self.graph_conv.weight
  49. return res