|
- import torch
- from .dropout import dropout_sparse
- from .weights import init_glorot
-
-
class SparseGraphConv(torch.nn.Module):
    """Convolution layer for sparse inputs.

    The forward pass computes ``adjacency_matrix @ (x @ weight)`` with
    ``torch.sparse.mm``.

    Args:
        in_channels: First argument handed to ``init_glorot``; presumably the
            number of input feature columns (confirm against ``init_glorot``).
        out_channels: Second argument handed to ``init_glorot``; presumably
            the number of output feature columns.
        adjacency_matrix: Matrix that left-multiplies the projected features
            in ``forward``.
        **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, in_channels, out_channels,
                 adjacency_matrix, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels

        weight = init_glorot(in_channels, out_channels)
        # init_glorot lives in another module. If it hands back a bare tensor,
        # wrap it so the weight is visible to .parameters(), .to() and
        # state_dict(); if it is already a Parameter this is a no-op.
        if (isinstance(weight, torch.Tensor)
                and not isinstance(weight, torch.nn.Parameter)):
            weight = torch.nn.Parameter(weight)
        self.weight = weight

        if isinstance(adjacency_matrix, torch.Tensor):
            # A buffer follows the module across devices. persistent=False
            # keeps it out of state_dict, so existing checkpoints still load.
            self.register_buffer(
                "adjacency_matrix", adjacency_matrix, persistent=False)
        else:
            # Non-tensor values are stored exactly as before.
            self.adjacency_matrix = adjacency_matrix

    def forward(self, x):
        """Return ``adjacency_matrix @ (x @ weight)``.

        Args:
            x: Input feature matrix; it is the first operand of
                ``torch.sparse.mm`` (this layer is meant for sparse inputs).

        Returns:
            The result of the two chained ``torch.sparse.mm`` products.
        """
        projected = torch.sparse.mm(x, self.weight)
        return torch.sparse.mm(self.adjacency_matrix, projected)
-
-
class SparseDropoutGraphConvActivation(torch.nn.Module):
    """Sparse dropout, then a ``SparseGraphConv``, then an activation.

    Args:
        input_dim: Passed to ``SparseGraphConv`` as ``in_channels``.
        output_dim: Passed to ``SparseGraphConv`` as ``out_channels``.
        adjacency_matrix: Passed to ``SparseGraphConv`` as its adjacency
            matrix.
        keep_prob: Second argument given to ``dropout_sparse`` on every
            forward pass (defaults to ``1.``).
        activation: Callable applied to the convolution output
            (defaults to ``torch.nn.functional.relu``).
        **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim,
                 adjacency_matrix, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        # SparseGraphConv requires the adjacency matrix; omitting it made
        # every construction of this layer fail with a TypeError.
        self.sparse_graph_conv = SparseGraphConv(
            input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x):
        """Apply dropout, graph convolution and activation to ``x`` in order.

        Args:
            x: Input handed to ``dropout_sparse``; presumably a sparse
                feature matrix (confirm against ``dropout_sparse``).

        Returns:
            ``activation(sparse_graph_conv(dropout_sparse(x, keep_prob)))``.
        """
        x = dropout_sparse(x, self.keep_prob)
        x = self.sparse_graph_conv(x)
        return self.activation(x)
-
-
class SparseMultiDGCA(torch.nn.Module):
    """Sum of several dropout/graph-conv/activation layers, row-normalized.

    One ``SparseDropoutGraphConvActivation`` is built per adjacency matrix.
    The forward pass adds their outputs and L2-normalizes each row.

    Args:
        input_dim: Input dimension handed to every sub-layer.
        output_dim: Output dimension handed to every sub-layer; also the
            column count of the accumulator created in ``forward``.
        adjacency_matrices: Iterable of adjacency matrices, one per
            sub-layer.
        keep_prob: Dropout keep probability handed to every sub-layer
            (defaults to ``1.``).
        activation: Activation callable handed to every sub-layer
            (defaults to ``torch.nn.functional.relu``).
        **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim,
                 adjacency_matrices, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        # forward() needs the output width; the bare name `output_dim` it
        # used before only existed inside __init__ and raised NameError.
        self.output_dim = output_dim
        # ModuleList (not a plain list) so the sub-layers are seen by
        # .parameters(), .to(), .train()/.eval() and state_dict().
        self.sparse_dgca = torch.nn.ModuleList(
            SparseDropoutGraphConvActivation(
                input_dim, output_dim, adj_mat, keep_prob, activation)
            for adj_mat in adjacency_matrices
        )

    def forward(self, x):
        """Return the L2 row-normalized sum of all sub-layer outputs.

        Args:
            x: Input passed unchanged to every sub-layer.

        Returns:
            Tensor of shape ``(x.shape[0], output_dim)`` normalized with
            ``p=2`` along ``dim=1``. With no sub-layers this is all zeros.
        """
        # Allocate on x's device so the accumulation works off-CPU.
        out = torch.zeros(
            x.shape[0], self.output_dim, dtype=x.dtype, device=x.device)
        for layer in self.sparse_dgca:
            out = out + layer(x)
        return torch.nn.functional.normalize(out, p=2, dim=1)
|