IF YOU WOULD LIKE AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

103 lines
3.4KB

  1. import torch
  2. from .dropout import dropout_sparse, \
  3. dropout
  4. from .weights import init_glorot
  class SparseGraphConv(torch.nn.Module):
      """Convolution layer for sparse inputs.

      Computes ``adjacency_matrix @ (x @ weight)`` using ``torch.sparse.mm``.

      Args:
          in_channels: Stored as ``self.in_channels`` and passed to
              ``init_glorot`` as the first weight dimension.
          out_channels: Stored as ``self.out_channels`` and passed to
              ``init_glorot`` as the second weight dimension.
          adjacency_matrix: Left-multiplied onto the transformed features.
              Presumably a sparse (num_out_nodes, num_in_nodes) tensor --
              TODO confirm with callers.
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, in_channels, out_channels,
              adjacency_matrix, **kwargs):
          super().__init__(**kwargs)
          self.in_channels = in_channels
          self.out_channels = out_channels
          # Wrapping in Parameter guarantees the weight is registered with the
          # module (parameters(), state_dict(), .to(device)) regardless of the
          # exact tensor type init_glorot returns.
          self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
          # Plain attribute, not a registered buffer: .to(device) will not
          # move it.
          self.adjacency_matrix = adjacency_matrix

      def forward(self, x):
          """Return ``adjacency_matrix @ (x @ weight)``.

          ``x`` is the first operand of ``torch.sparse.mm``, so it is expected
          to be a sparse tensor.
          """
          x = torch.sparse.mm(x, self.weight)
          x = torch.sparse.mm(self.adjacency_matrix, x)
          return x
  class SparseDropoutGraphConvActivation(torch.nn.Module):
      """Sparse dropout, then a SparseGraphConv, then an activation function.

      Args:
          input_dim: Forwarded to SparseGraphConv as ``in_channels``.
          output_dim: Forwarded to SparseGraphConv as ``out_channels``.
          adjacency_matrix: Forwarded to SparseGraphConv.
          keep_prob: Passed to ``dropout_sparse``. Presumably the probability
              of keeping an entry, in which case the default ``1.`` disables
              dropout -- confirm against dropout_sparse.
          activation: Callable applied to the convolution output
              (default ``torch.nn.functional.relu``).
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, input_dim, output_dim,
              adjacency_matrix, keep_prob=1.,
              activation=torch.nn.functional.relu,
              **kwargs):
          super().__init__(**kwargs)
          conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
          self.sparse_graph_conv = conv
          self.keep_prob = keep_prob
          self.activation = activation

      def forward(self, x):
          """Apply ``dropout_sparse``, the graph convolution and the activation."""
          # NOTE(review): dropout runs regardless of self.training; confirm
          # whether that is intended at evaluation time.
          dropped = dropout_sparse(x, self.keep_prob)
          convolved = self.sparse_graph_conv(dropped)
          return self.activation(convolved)
  class SparseMultiDGCA(torch.nn.Module):
      """Sum of several SparseDropoutGraphConvActivation layers, row-normalized.

      One SparseDropoutGraphConvActivation is built per entry of
      ``adjacency_matrices``, all sharing ``input_dim``, ``output_dim``,
      ``keep_prob`` and ``activation``. ``forward`` adds their outputs and
      L2-normalizes every row of the sum.

      Args:
          input_dim: Input feature width for each sub-layer.
          output_dim: Output feature width for each sub-layer.
          adjacency_matrices: Iterable of adjacency matrices, one sub-layer each.
          keep_prob: Forwarded to every sub-layer.
          activation: Forwarded to every sub-layer.
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, input_dim, output_dim,
              adjacency_matrices, keep_prob=1.,
              activation=torch.nn.functional.relu,
              **kwargs):
          super().__init__(**kwargs)
          self.input_dim = input_dim
          # forward() reads output_dim; it previously was never stored, which
          # made every forward() call raise AttributeError.
          self.output_dim = output_dim
          # ModuleList (not a plain list) registers the sub-layers so that
          # parameters(), .to(device), train()/eval() and state_dict() see them.
          self.sparse_dgca = torch.nn.ModuleList([
              SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat,
                  keep_prob, activation)
              for adj_mat in adjacency_matrices
          ])

      def forward(self, x):
          """Sum the sub-layer outputs for ``x`` and L2-normalize each row.

          With no sub-layers, returns a zero tensor of shape
          ``(len(x), output_dim)`` with ``x``'s dtype and device.
          """
          out = None
          for f in self.sparse_dgca:
              y = f(x)
              # Seed with the first output rather than zeros shaped like x, so
              # adjacency matrices whose row count differs from len(x) work.
              out = y if out is None else out + y
          if out is None:
              out = torch.zeros(len(x), self.output_dim, dtype=x.dtype,
                  device=x.device)
          out = torch.nn.functional.normalize(out, p=2, dim=1)
          return out
  class GraphConv(torch.nn.Module):
      """Graph convolution layer for dense inputs.

      Computes ``adjacency_matrix @ (x @ weight)`` using ``torch.mm``.

      Args:
          in_channels: Stored as ``self.in_channels`` and passed to
              ``init_glorot`` as the first weight dimension.
          out_channels: Stored as ``self.out_channels`` and passed to
              ``init_glorot`` as the second weight dimension.
          adjacency_matrix: Dense matrix left-multiplied onto the transformed
              features. Presumably (num_out_nodes, num_in_nodes) -- TODO
              confirm with callers.
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, in_channels, out_channels,
              adjacency_matrix, **kwargs):
          super().__init__(**kwargs)
          self.in_channels = in_channels
          self.out_channels = out_channels
          # Wrapping in Parameter guarantees the weight is registered with the
          # module (parameters(), state_dict(), .to(device)) regardless of the
          # exact tensor type init_glorot returns.
          self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
          # Plain attribute, not a registered buffer: .to(device) will not
          # move it.
          self.adjacency_matrix = adjacency_matrix

      def forward(self, x):
          """Return ``adjacency_matrix @ (x @ weight)``."""
          x = torch.mm(x, self.weight)
          x = torch.mm(self.adjacency_matrix, x)
          return x
  class DropoutGraphConvActivation(torch.nn.Module):
      """Dropout, then a dense GraphConv, then an activation function.

      Args:
          input_dim: Forwarded to GraphConv as ``in_channels``.
          output_dim: Forwarded to GraphConv as ``out_channels``.
          adjacency_matrix: Forwarded to GraphConv.
          keep_prob: Passed to ``dropout`` as ``keep_prob``. Presumably the
              probability of keeping an entry, so the default ``1.`` disables
              dropout -- confirm against dropout.
          activation: Callable applied to the convolution output
              (default ``torch.nn.functional.relu``).
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, input_dim, output_dim,
              adjacency_matrix, keep_prob=1.,
              activation=torch.nn.functional.relu,
              **kwargs):
          super().__init__(**kwargs)
          conv = GraphConv(input_dim, output_dim, adjacency_matrix)
          self.graph_conv = conv
          self.keep_prob = keep_prob
          self.activation = activation

      def forward(self, x):
          """Apply ``dropout``, the graph convolution and the activation."""
          # NOTE(review): dropout runs regardless of self.training; confirm
          # whether that is intended at evaluation time.
          dropped = dropout(x, keep_prob=self.keep_prob)
          convolved = self.graph_conv(dropped)
          return self.activation(convolved)
  class MultiDGCA(torch.nn.Module):
      """Sum of several DropoutGraphConvActivation layers, row-normalized.

      One DropoutGraphConvActivation is built per entry of
      ``adjacency_matrices``, all sharing ``input_dim``, ``output_dim``,
      ``keep_prob`` and ``activation``. ``forward`` adds their outputs and
      L2-normalizes every row of the sum.

      Args:
          input_dim: Input feature width for each sub-layer.
          output_dim: Output feature width for each sub-layer.
          adjacency_matrices: Iterable of adjacency matrices, one sub-layer each.
          keep_prob: Forwarded to every sub-layer.
          activation: Forwarded to every sub-layer.
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, input_dim, output_dim,
              adjacency_matrices, keep_prob=1.,
              activation=torch.nn.functional.relu,
              **kwargs):
          super().__init__(**kwargs)
          self.input_dim = input_dim
          # forward() reads output_dim; it previously was never stored, which
          # made every forward() call raise AttributeError.
          self.output_dim = output_dim
          # ModuleList (not a plain list) registers the sub-layers so that
          # parameters(), .to(device), train()/eval() and state_dict() see them.
          self.dgca = torch.nn.ModuleList([
              DropoutGraphConvActivation(input_dim, output_dim, adj_mat,
                  keep_prob, activation)
              for adj_mat in adjacency_matrices
          ])

      def forward(self, x):
          """Sum the sub-layer outputs for ``x`` and L2-normalize each row.

          With no sub-layers, returns a zero tensor of shape
          ``(len(x), output_dim)`` with ``x``'s dtype and device.
          """
          out = None
          for f in self.dgca:
              y = f(x)
              # Seed with the first output rather than zeros shaped like x, so
              # adjacency matrices whose row count differs from len(x) work.
              out = y if out is None else out + y
          if out is None:
              out = torch.zeros(len(x), self.output_dim, dtype=x.dtype,
                  device=x.device)
          out = torch.nn.functional.normalize(out, p=2, dim=1)
          return out