IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

54 lines
1.8KB

  1. import torch
  2. from .dropout import dropout_sparse
  3. from .weights import init_glorot
  4. class SparseGraphConv(torch.nn.Module):
  5. """Convolution layer for sparse inputs."""
  6. def __init__(self, in_channels, out_channels,
  7. adjacency_matrix, **kwargs):
  8. super().__init__(**kwargs)
  9. self.in_channels = in_channels
  10. self.out_channels = out_channels
  11. self.weight = init_glorot(in_channels, out_channels)
  12. self.adjacency_matrix = adjacency_matrix
  13. def forward(self, x):
  14. x = torch.sparse.mm(x, self.weight)
  15. x = torch.sparse.mm(self.adjacency_matrix, x)
  16. return x
  17. class SparseDropoutGraphConvActivation(torch.nn.Module):
  18. def __init__(self, input_dim, output_dim,
  19. adjacency_matrix, keep_prob=1.,
  20. activation=torch.nn.functional.relu,
  21. **kwargs):
  22. super().__init__(**kwargs)
  23. self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim)
  24. self.keep_prob = keep_prob
  25. self.activation = activation
  26. def forward(self, x):
  27. x = dropout_sparse(x, self.keep_prob)
  28. x = self.sparse_graph_conv(x)
  29. x = self.activation(x)
  30. return x
  31. class SparseMultiDGCA(torch.nn.Module):
  32. def __init__(self, input_dim, output_dim,
  33. adjacency_matrices, keep_prob=1.,
  34. activation=torch.nn.functional.relu,
  35. **kwargs):
  36. super().__init__(**kwargs)
  37. self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
  38. def forward(self, x):
  39. out = torch.zeros(len(x), output_dim, dtype=x.dtype)
  40. for f in self.sparse_dgca:
  41. out += f(x)
  42. out = torch.nn.functional.normalize(out, p=2, dim=1)
  43. return out