|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- import torch
- from .dropout import dropout_sparse, \
- dropout
- from .weights import init_glorot
-
-
class SparseGraphConv(torch.nn.Module):
    """Convolution layer for sparse inputs.

    Computes ``adjacency_matrix @ (x @ weight)``, using ``torch.sparse.mm``
    for both products.

    Args:
        in_channels: Input feature dimension passed to ``init_glorot``.
        out_channels: Output feature dimension passed to ``init_glorot``.
        adjacency_matrix: Tensor left-multiplied onto the projected features;
            its column count must equal the number of rows of ``x``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, in_channels, out_channels,
                 adjacency_matrix, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Wrap in Parameter so the weight is registered with the module
        # (parameters(), state_dict(), .to()) regardless of whether
        # init_glorot already returns a Parameter or a plain tensor.
        self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
        # NOTE(review): kept as a plain attribute (not a buffer), so
        # Module.to()/.cuda() will not move it -- confirm callers place it
        # on the same device as the weight.
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x):
        """Return ``adjacency_matrix @ (x @ weight)``.

        Args:
            x: Input features whose column count matches the weight's row
                count (presumably a sparse tensor, given the layer's purpose).

        Returns:
            Tensor with ``adjacency_matrix.shape[0]`` rows and as many
            columns as the weight (presumably ``out_channels``).
        """
        x = torch.sparse.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x)
        return x
-
-
class SparseDropoutGraphConvActivation(torch.nn.Module):
    """Sparse dropout, then a ``SparseGraphConv``, then an activation.

    Args:
        input_dim: Input feature dimension of the wrapped convolution.
        output_dim: Output feature dimension of the wrapped convolution.
        adjacency_matrix: Adjacency tensor handed to ``SparseGraphConv``.
        keep_prob: Keep probability passed to ``dropout_sparse``.
        activation: Callable applied to the convolution output.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim,
                 adjacency_matrix, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        self.sparse_graph_conv = SparseGraphConv(
            in_channels=input_dim,
            out_channels=output_dim,
            adjacency_matrix=adjacency_matrix,
        )
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x):
        """Return ``activation(sparse_graph_conv(dropout_sparse(x, keep_prob)))``.

        NOTE(review): this module never consults ``self.training``; whether
        ``dropout_sparse`` can tell training from evaluation is not visible
        here -- confirm dropout is meant to run in eval mode too.
        """
        dropped = dropout_sparse(x, self.keep_prob)
        return self.activation(self.sparse_graph_conv(dropped))
-
-
class SparseMultiDGCA(torch.nn.Module):
    """Sum of several ``SparseDropoutGraphConvActivation`` branches, row-normalized.

    One branch is built per adjacency matrix. All branches share the input
    and output dimensions, dropout keep probability and activation. Their
    outputs are summed and each row is L2-normalized.

    Args:
        input_dim: Input feature dimension of every branch.
        output_dim: Output feature dimension of every branch.
        adjacency_matrices: Iterable of adjacency tensors, one per branch.
        keep_prob: Keep probability passed to each branch's dropout.
        activation: Activation applied by each branch.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim,
                 adjacency_matrices, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        # ModuleList (not a plain list) so the branches -- and their
        # weights -- are registered as submodules: visible to
        # parameters()/state_dict() and affected by .to()/.train()/.eval().
        self.sparse_dgca = torch.nn.ModuleList([
            SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat,
                                             keep_prob, activation)
            for adj_mat in adjacency_matrices
        ])

    def forward(self, x):
        """Sum the branch outputs for ``x`` and L2-normalize each row.

        Args:
            x: Input features fed unchanged to every branch.

        Returns:
            Row-normalized sum of branch outputs. Its row count is taken from
            the branch outputs (the adjacency matrices' row count), which need
            not equal ``len(x)``. With no branches, this is a zero tensor of
            shape ``(len(x), output_dim)``.
        """
        out = None
        for branch in self.sparse_dgca:
            y = branch(x)
            # Out-of-place sum; shape and device come from the branch output.
            out = y if out is None else out + y
        if out is None:
            # No branches: normalizing zeros would leave them zero anyway.
            return torch.zeros(len(x), self.output_dim,
                               dtype=x.dtype, device=x.device)
        return torch.nn.functional.normalize(out, p=2, dim=1)
-
-
class GraphConv(torch.nn.Module):
    """Dense graph convolution layer.

    Computes ``adjacency_matrix @ (x @ weight)`` with ``torch.mm``.

    Args:
        in_channels: Input feature dimension passed to ``init_glorot``.
        out_channels: Output feature dimension passed to ``init_glorot``.
        adjacency_matrix: Tensor left-multiplied onto the projected features;
            its column count must equal the number of rows of ``x``.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, in_channels, out_channels,
                 adjacency_matrix, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Wrap in Parameter so the weight is registered with the module
        # (parameters(), state_dict(), .to()) regardless of whether
        # init_glorot already returns a Parameter or a plain tensor.
        self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
        # NOTE(review): kept as a plain attribute (not a buffer), so
        # Module.to()/.cuda() will not move it -- confirm callers place it
        # on the same device as the weight.
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x):
        """Return ``adjacency_matrix @ (x @ weight)``.

        Args:
            x: 2-D input features whose column count matches the weight's
                row count.

        Returns:
            Tensor with ``adjacency_matrix.shape[0]`` rows and as many
            columns as the weight (presumably ``out_channels``).
        """
        x = torch.mm(x, self.weight)
        x = torch.mm(self.adjacency_matrix, x)
        return x
-
-
class DropoutGraphConvActivation(torch.nn.Module):
    """Dropout, then a dense ``GraphConv``, then an activation.

    Args:
        input_dim: Input feature dimension of the wrapped convolution.
        output_dim: Output feature dimension of the wrapped convolution.
        adjacency_matrix: Adjacency tensor handed to ``GraphConv``.
        keep_prob: Keep probability passed to ``dropout``.
        activation: Callable applied to the convolution output.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim,
                 adjacency_matrix, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        self.graph_conv = GraphConv(
            in_channels=input_dim,
            out_channels=output_dim,
            adjacency_matrix=adjacency_matrix,
        )
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x):
        """Return ``activation(graph_conv(dropout(x, keep_prob=keep_prob)))``.

        NOTE(review): this module never consults ``self.training``; whether
        ``dropout`` can tell training from evaluation is not visible here --
        confirm dropout is meant to run in eval mode too.
        """
        dropped = dropout(x, keep_prob=self.keep_prob)
        return self.activation(self.graph_conv(dropped))
-
-
class MultiDGCA(torch.nn.Module):
    """Sum of several ``DropoutGraphConvActivation`` branches, row-normalized.

    One branch is built per adjacency matrix. All branches share the input
    and output dimensions, dropout keep probability and activation. Their
    outputs are summed and each row is L2-normalized.

    Args:
        input_dim: Input feature dimension of every branch.
        output_dim: Output feature dimension of every branch.
        adjacency_matrices: Iterable of adjacency tensors, one per branch.
        keep_prob: Keep probability passed to each branch's dropout.
        activation: Activation applied by each branch.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim,
                 adjacency_matrices, keep_prob=1.,
                 activation=torch.nn.functional.relu,
                 **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        # ModuleList (not a plain list) so the branches -- and their
        # weights -- are registered as submodules: visible to
        # parameters()/state_dict() and affected by .to()/.train()/.eval().
        self.dgca = torch.nn.ModuleList([
            DropoutGraphConvActivation(input_dim, output_dim, adj_mat,
                                       keep_prob, activation)
            for adj_mat in adjacency_matrices
        ])

    def forward(self, x):
        """Sum the branch outputs for ``x`` and L2-normalize each row.

        Args:
            x: Input features fed unchanged to every branch.

        Returns:
            Row-normalized sum of branch outputs. Its row count is taken from
            the branch outputs (the adjacency matrices' row count), which need
            not equal ``len(x)``. With no branches, this is a zero tensor of
            shape ``(len(x), output_dim)``.
        """
        out = None
        for branch in self.dgca:
            y = branch(x)
            # Out-of-place sum; shape and device come from the branch output.
            out = y if out is None else out + y
        if out is None:
            # No branches: normalizing zeros would leave them zero anyway.
            return torch.zeros(len(x), self.output_dim,
                               dtype=x.dtype, device=x.device)
        return torch.nn.functional.normalize(out, p=2, dim=1)
|