import torch

from .dropout import dropout_sparse
from .weights import init_glorot


class SparseGraphConv(torch.nn.Module):
    """Graph convolution layer for sparse inputs.

    Computes ``A @ (X @ W)`` where ``A`` is a (sparse) adjacency matrix
    and ``W`` is a Glorot-initialized weight matrix.
    """

    def __init__(self, in_channels, out_channels, adjacency_matrix, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Wrap in Parameter so the weight is registered with the module
        # (picked up by .parameters(), .to(), state_dict(), optimizers).
        self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x):
        """Apply the convolution to sparse input ``x``.

        :param x: sparse input features of shape (num_nodes, in_channels)
        :returns: propagated features of shape (num_nodes, out_channels)
        """
        x = torch.sparse.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x)
        return x


class SparseDropoutGraphConvActivation(torch.nn.Module):
    """Sparse dropout -> sparse graph convolution -> activation."""

    def __init__(self, input_dim, output_dim, adjacency_matrix,
                 keep_prob=1., activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        # BUGFIX: adjacency_matrix was previously not forwarded, making
        # SparseGraphConv(...) raise TypeError (missing positional arg).
        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim,
                                                 adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x):
        x = dropout_sparse(x, self.keep_prob)
        x = self.sparse_graph_conv(x)
        x = self.activation(x)
        return x


class SparseMultiDGCA(torch.nn.Module):
    """Sum of dropout-graph-conv-activation branches, one per adjacency
    matrix, followed by L2 row normalization."""

    def __init__(self, input_dim, output_dim, adjacency_matrices,
                 keep_prob=1., activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        # BUGFIX: output_dim must be stored; forward() previously
        # referenced an undefined bare name and raised NameError.
        self.input_dim = input_dim
        self.output_dim = output_dim
        # BUGFIX: use ModuleList instead of a plain list so the branches'
        # parameters are registered with this module.
        self.sparse_dgca = torch.nn.ModuleList([
            SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat,
                                             keep_prob, activation)
            for adj_mat in adjacency_matrices
        ])

    def forward(self, x):
        """Sum all branch outputs for ``x`` and L2-normalize each row.

        :param x: sparse input features of shape (num_nodes, input_dim)
        :returns: dense output of shape (num_nodes, output_dim)
        """
        # device= added so the accumulator lives where the input lives.
        out = torch.zeros(len(x), self.output_dim,
                          dtype=x.dtype, device=x.device)
        for f in self.sparse_dgca:
            out += f(x)
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out