|
|
@@ -1,104 +1,272 @@ |
|
|
|
import torch
|
|
|
|
from .dropout import dropout_sparse, \
|
|
|
|
dropout
|
|
|
|
from .weights import init_glorot
|
|
|
|
|
|
|
|
|
|
|
|
class SparseGraphConv(torch.nn.Module):
|
|
|
|
"""Convolution layer for sparse inputs."""
|
|
|
|
def __init__(self, in_channels, out_channels,
|
|
|
|
adjacency_matrix, **kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.in_channels = in_channels
|
|
|
|
self.out_channels = out_channels
|
|
|
|
self.weight = init_glorot(in_channels, out_channels)
|
|
|
|
self.adjacency_matrix = adjacency_matrix
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
x = torch.sparse.mm(x, self.weight)
|
|
|
|
x = torch.sparse.mm(self.adjacency_matrix, x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
class SparseDropoutGraphConvActivation(torch.nn.Module):
|
|
|
|
def __init__(self, input_dim, output_dim,
|
|
|
|
adjacency_matrix, keep_prob=1.,
|
|
|
|
activation=torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
x = dropout_sparse(x, self.keep_prob)
|
|
|
|
x = self.sparse_graph_conv(x)
|
|
|
|
x = self.activation(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
class SparseMultiDGCA(torch.nn.Module):
|
|
|
|
def __init__(self, input_dim, output_dim,
|
|
|
|
adjacency_matrices, keep_prob=1.,
|
|
|
|
activation=torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.output_dim = output_dim
|
|
|
|
self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
|
|
|
|
for f in self.sparse_dgca:
|
|
|
|
out += f(x)
|
|
|
|
out = torch.nn.functional.normalize(out, p=2, dim=1)
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
class GraphConv(torch.nn.Module):
|
|
|
|
def __init__(self, in_channels, out_channels,
|
|
|
|
adjacency_matrix, **kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.in_channels = in_channels
|
|
|
|
self.out_channels = out_channels
|
|
|
|
self.weight = init_glorot(in_channels, out_channels)
|
|
|
|
self.adjacency_matrix = adjacency_matrix
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
x = torch.mm(x, self.weight)
|
|
|
|
x = torch.mm(self.adjacency_matrix, x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
class DropoutGraphConvActivation(torch.nn.Module):
|
|
|
|
def __init__(self, input_dim, output_dim,
|
|
|
|
adjacency_matrix, keep_prob=1.,
|
|
|
|
activation=torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
x = dropout(x, keep_prob=self.keep_prob)
|
|
|
|
x = self.graph_conv(x)
|
|
|
|
x = self.activation(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
class MultiDGCA(torch.nn.Module):
|
|
|
|
def __init__(self, input_dim, output_dim,
|
|
|
|
adjacency_matrices, keep_prob=1.,
|
|
|
|
activation=torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.output_dim = output_dim
|
|
|
|
self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
|
|
|
|
for f in self.dgca:
|
|
|
|
out += f(x)
|
|
|
|
out = torch.nn.functional.normalize(out, p=2, dim=1)
|
|
|
|
return out
|
|
|
|
# |
|
|
|
# Copyright (C) Stanislaw Adaszewski, 2020 |
|
|
|
# License: GPLv3 |
|
|
|
# |
|
|
|
|
|
|
|
""" |
|
|
|
This module implements the basic convolutional blocks of Decagon. |
|
|
|
Just as a quick reminder, the basic convolution formula here is: |
|
|
|
|
|
|
|
y = A * (x * W) |
|
|
|
|
|
|
|
where: |
|
|
|
|
|
|
|
W is a weight matrix |
|
|
|
A is an adjacency matrix |
|
|
|
x is a matrix of latent representations of a particular type of neighbors. |
|
|
|
|
|
|
|
As we have x here twice, a trick is obviously necessary for this to work. |
|
|
|
A must be previously normalized with: |
|
|
|
|
|
|
|
c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) |
|
|
|
|
|
|
|
or |
|
|
|
|
|
|
|
c_{r}^{i} = 1/|N_{r}^{i}| |
|
|
|
|
|
|
|
Let's work through this step by step to convince ourselves that the |
|
|
|
formula is correct. |
|
|
|
|
|
|
|
x = [ |
|
|
|
[0, 1, 0, 1], |
|
|
|
[1, 1, 1, 0], |
|
|
|
[0, 0, 0, 1] |
|
|
|
] |
|
|
|
|
|
|
|
W = [ |
|
|
|
[0, 1], |
|
|
|
[1, 0], |
|
|
|
[0.5, 0.5], |
|
|
|
[0.25, 0.75] |
|
|
|
] |
|
|
|
|
|
|
|
A = [ |
|
|
|
[0, 1, 0], |
|
|
|
[1, 0, 1], |
|
|
|
[0, 1, 0] |
|
|
|
] |
|
|
|
|
|
|
|
so the graph looks like this: |
|
|
|
|
|
|
|
(0) -- (1) -- (2) |
|
|
|
|
|
|
|
and therefore the representations in the next layer should be: |
|
|
|
|
|
|
|
h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k} |
|
|
|
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} + |
|
|
|
c_{r}^{1} * h_{1}^{k} |
|
|
|
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k} |
|
|
|
|
|
|
|
In actual Decagon code we can see that that latter part propagating directly |
|
|
|
the old representation is gone. I will try to do the same for now. |
|
|
|
|
|
|
|
So we have to only take care of: |
|
|
|
|
|
|
|
h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W |
|
|
|
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} |
|
|
|
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W |
|
|
|
|
|
|
|
If A is square the Decagon's EdgeMinibatchIterator preprocesses it as follows: |
|
|
|
|
|
|
|
A = A + eye(len(A)) |
|
|
|
rowsum = A.sum(1) |
|
|
|
deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) |
|
|
|
A = dot(A, deg_mat_inv_sqrt) |
|
|
|
A = A.transpose() |
|
|
|
A = A.dot(deg_mat_inv_sqrt) |
|
|
|
|
|
|
|
Let's see what gives in our case: |
|
|
|
|
|
|
|
A = A + eye(len(A)) |
|
|
|
|
|
|
|
[ |
|
|
|
[1, 1, 0], |
|
|
|
[1, 1, 1], |
|
|
|
[0, 1, 1] |
|
|
|
] |
|
|
|
|
|
|
|
rowsum = A.sum(1) |
|
|
|
|
|
|
|
[2, 3, 2] |
|
|
|
|
|
|
|
deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) |
|
|
|
|
|
|
|
[ |
|
|
|
[1./sqrt(2), 0, 0], |
|
|
|
[0, 1./sqrt(3), 0], |
|
|
|
[0, 0, 1./sqrt(2)] |
|
|
|
] |
|
|
|
|
|
|
|
A = dot(A, deg_mat_inv_sqrt) |
|
|
|
|
|
|
|
[ |
|
|
|
[ 1/sqrt(2), 1/sqrt(3), 0 ], |
|
|
|
[ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ], |
|
|
|
[ 0, 1/sqrt(3), 1/sqrt(2) ] |
|
|
|
] |
|
|
|
|
|
|
|
A = A.transpose() |
|
|
|
|
|
|
|
[ |
|
|
|
[ 1/sqrt(2), 1/sqrt(2), 0 ], |
|
|
|
[ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ], |
|
|
|
[ 0, 1/sqrt(2), 1/sqrt(2) ] |
|
|
|
] |
|
|
|
|
|
|
|
A = A.dot(deg_mat_inv_sqrt) |
|
|
|
|
|
|
|
[ |
|
|
|
[ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ], |
|
|
|
[ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ], |
|
|
|
[ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ], |
|
|
|
] |
|
|
|
|
|
|
|
thus: |
|
|
|
|
|
|
|
[ |
|
|
|
[0.5 , 0.40824829, 0. ], |
|
|
|
[0.40824829, 0.33333333, 0.40824829], |
|
|
|
[0. , 0.40824829, 0.5 ] |
|
|
|
] |
|
|
|
|
|
|
|
This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula. |
|
|
|
|
|
|
|
Then, we get back to the main calculation: |
|
|
|
|
|
|
|
y = x * W |
|
|
|
y = A * y |
|
|
|
|
|
|
|
y = x * W |
|
|
|
|
|
|
|
[ |
|
|
|
[ 1.25, 0.75 ], |
|
|
|
[ 1.5 , 1.5 ], |
|
|
|
[ 0.25, 0.75 ] |
|
|
|
] |
|
|
|
|
|
|
|
y = A * y |
|
|
|
|
|
|
|
[ |
|
|
|
0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ], |
|
|
|
0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ], |
|
|
|
0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ] |
|
|
|
] |
|
|
|
|
|
|
|
that is: |
|
|
|
|
|
|
|
[ |
|
|
|
[1.23737243, 0.98737244], |
|
|
|
[1.11237243, 1.11237243], |
|
|
|
[0.73737244, 0.98737244] |
|
|
|
]. |
|
|
|
|
|
|
|
All checks out nicely, good. |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
from .dropout import dropout_sparse, \ |
|
|
|
dropout |
|
|
|
from .weights import init_glorot |
|
|
|
from typing import List, Callable |
|
|
|
|
|
|
|
|
|
|
|
class SparseGraphConv(torch.nn.Module): |
|
|
|
"""Convolution layer for sparse inputs.""" |
|
|
|
def __init__(self, in_channels: int, out_channels: int, |
|
|
|
adjacency_matrix: torch.Tensor, **kwargs) -> None: |
|
|
|
super().__init__(**kwargs) |
|
|
|
self.in_channels = in_channels |
|
|
|
self.out_channels = out_channels |
|
|
|
self.weight = init_glorot(in_channels, out_channels) |
|
|
|
self.adjacency_matrix = adjacency_matrix |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
x = torch.sparse.mm(x, self.weight) |
|
|
|
x = torch.sparse.mm(self.adjacency_matrix, x) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class SparseDropoutGraphConvActivation(torch.nn.Module): |
|
|
|
def __init__(self, input_dim: int, output_dim: int, |
|
|
|
adjacency_matrix: torch.Tensor, keep_prob: float=1., |
|
|
|
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, |
|
|
|
**kwargs) -> None: |
|
|
|
super().__init__(**kwargs) |
|
|
|
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix) |
|
|
|
self.keep_prob = keep_prob |
|
|
|
self.activation = activation |
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
x = dropout_sparse(x, self.keep_prob) |
|
|
|
x = self.sparse_graph_conv(x) |
|
|
|
x = self.activation(x) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class SparseMultiDGCA(torch.nn.Module): |
|
|
|
def __init__(self, input_dim: List[int], output_dim: int, |
|
|
|
adjacency_matrices: List[torch.Tensor], keep_prob: float=1., |
|
|
|
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, |
|
|
|
**kwargs) -> None: |
|
|
|
super().__init__(**kwargs) |
|
|
|
self.output_dim = output_dim |
|
|
|
self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] |
|
|
|
|
|
|
|
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: |
|
|
|
out = torch.zeros(len(x), self.output_dim, dtype=x.dtype) |
|
|
|
for f in self.sparse_dgca: |
|
|
|
out += f(x) |
|
|
|
out = torch.nn.functional.normalize(out, p=2, dim=1) |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class GraphConv(torch.nn.Module): |
|
|
|
def __init__(self, in_channels: int, out_channels: int, |
|
|
|
adjacency_matrix: torch.Tensor, **kwargs) -> None: |
|
|
|
super().__init__(**kwargs) |
|
|
|
self.in_channels = in_channels |
|
|
|
self.out_channels = out_channels |
|
|
|
self.weight = init_glorot(in_channels, out_channels) |
|
|
|
self.adjacency_matrix = adjacency_matrix |
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
x = torch.mm(x, self.weight) |
|
|
|
x = torch.mm(self.adjacency_matrix, x) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class DropoutGraphConvActivation(torch.nn.Module): |
|
|
|
def __init__(self, input_dim: int, output_dim: int, |
|
|
|
adjacency_matrix: torch.Tensor, keep_prob: float=1., |
|
|
|
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, |
|
|
|
**kwargs) -> None: |
|
|
|
super().__init__(**kwargs) |
|
|
|
self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) |
|
|
|
self.keep_prob = keep_prob |
|
|
|
self.activation = activation |
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
x = dropout(x, keep_prob=self.keep_prob) |
|
|
|
x = self.graph_conv(x) |
|
|
|
x = self.activation(x) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class MultiDGCA(torch.nn.Module): |
|
|
|
def __init__(self, input_dim: List[int], output_dim: int, |
|
|
|
adjacency_matrices: List[torch.Tensor], keep_prob: float=1., |
|
|
|
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, |
|
|
|
**kwargs) -> None: |
|
|
|
super().__init__(**kwargs) |
|
|
|
self.output_dim = output_dim |
|
|
|
self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] |
|
|
|
|
|
|
|
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: |
|
|
|
out = torch.zeros(len(x), self.output_dim, dtype=x.dtype) |
|
|
|
for f in self.dgca: |
|
|
|
out += f(x) |
|
|
|
out = torch.nn.functional.normalize(out, p=2, dim=1) |
|
|
|
return out |