import torch

from .dropout import dropout_sparse, dropout
from .weights import init_glorot


class SparseGraphConv(torch.nn.Module):
    """Convolution layer for sparse inputs."""

    def __init__(self, in_channels, out_channels,
                 adjacency_matrix, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x):
        # Project sparse node features, then aggregate them over the graph.
        x = torch.sparse.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x)
        return x


class SparseDropoutGraphConvActivation(torch.nn.Module):
    """Sparse dropout followed by a sparse graph convolution and an activation."""

    def __init__(self, input_dim, output_dim,
                 adjacency_matrix, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x):
        x = dropout_sparse(x, self.keep_prob)
        x = self.sparse_graph_conv(x)
        x = self.activation(x)
        return x


class SparseMultiDGCA(torch.nn.Module):
    """Sum of sparse dropout-graph-convolution-activation branches, one per
    adjacency matrix, followed by row-wise L2 normalization."""

    def __init__(self, input_dim, output_dim,
                 adjacency_matrices, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]

    def forward(self, x):
        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
        for f in self.sparse_dgca:
            out += f(x)
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
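

# A usage sketch for the sparse path (an added note, not part of the original
# code). Assuming init_glorot() returns a dense (in_channels, out_channels)
# weight tensor and dropout_sparse() takes (sparse_tensor, keep_prob), the
# classes above can be exercised roughly as follows:
#
#     adj = torch.eye(10).to_sparse()      # toy sparse adjacency, 10 nodes
#     feat = torch.eye(10).to_sparse()     # one-hot sparse node features
#     layer = SparseMultiDGCA(10, 16, [adj, adj], keep_prob=0.9)
#     out = layer(feat)                    # dense (10, 16), rows L2-normalized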


class GraphConv(torch.nn.Module):
    """Convolution layer for dense inputs."""

    def __init__(self, in_channels, out_channels,
                 adjacency_matrix, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x):
        # Project dense node features, then aggregate them over the graph.
        x = torch.mm(x, self.weight)
        x = torch.mm(self.adjacency_matrix, x)
        return x


class DropoutGraphConvActivation(torch.nn.Module):
    """Dropout followed by a dense graph convolution and an activation."""

    def __init__(self, input_dim, output_dim,
                 adjacency_matrix, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x):
        x = dropout(x, keep_prob=self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x


class MultiDGCA(torch.nn.Module):
    """Sum of dense dropout-graph-convolution-activation branches, one per
    adjacency matrix, followed by row-wise L2 normalization."""

    def __init__(self, input_dim, output_dim,
                 adjacency_matrices, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]

    def forward(self, x):
        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
        for f in self.dgca:
            out += f(x)
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
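
For orientation, here is a minimal usage sketch of the dense path. It is an added illustration, not part of the module above; it assumes that init_glorot() returns a dense (input_dim, output_dim) weight tensor and that dropout() accepts a keep_prob keyword argument, as the calls above suggest. With the classes above in scope:

    import torch

    n_nodes, in_dim, out_dim = 10, 10, 16
    adj = torch.eye(n_nodes)                 # toy adjacency matrix (self-loops only)
    x = torch.rand(n_nodes, in_dim)          # dense node features

    # One dropout-conv-activation branch per adjacency matrix; MultiDGCA sums
    # the branch outputs and L2-normalizes each row.
    multi = MultiDGCA(in_dim, out_dim, [adj, adj], keep_prob=0.9)
    out = multi(x)                           # shape: (10, 16)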