|
|
@@ -46,8 +46,56 @@ class SparseMultiDGCA(torch.nn.Module): |
|
|
|
self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
out = torch.zeros(len(x), output_dim, dtype=x.dtype)
|
|
|
|
out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
|
|
|
|
for f in self.sparse_dgca:
|
|
|
|
out += f(x)
|
|
|
|
out = torch.nn.functional.normalize(out, p=2, dim=1)
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
class GraphConv(torch.nn.Module):
|
|
|
|
def __init__(self, in_channels, out_channels,
|
|
|
|
adjacency_matrix, **kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.in_channels = in_channels
|
|
|
|
self.out_channels = out_channels
|
|
|
|
self.weight = init_glorot(in_channels, out_channels)
|
|
|
|
self.adjacency_matrix = adjacency_matrix
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
x = torch.mm(x, self.weight)
|
|
|
|
x = torch.mm(self.adjacency_matrix, x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
class DropoutGraphConvActivation(torch.nn.Module):
|
|
|
|
def __init__(self, input_dim, output_dim,
|
|
|
|
adjacency_matrix, keep_prob=1.,
|
|
|
|
activation=torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
x = torch.nn.functional.dropout(x, 1.-self.keep_prob)
|
|
|
|
x = self.graph_conv(x)
|
|
|
|
x = self.activation(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
class MultiDGCA(torch.nn.Module):
|
|
|
|
def __init__(self, input_dim, output_dim,
|
|
|
|
adjacency_matrices, keep_prob=1.,
|
|
|
|
activation=torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
|
|
|
|
for f in self.dgca:
|
|
|
|
out += f(x)
|
|
|
|
out = torch.nn.functional.normalize(out, p=2, dim=1)
|
|
|
|
return out
|