@@ -1,346 +0,0 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
""" | |||||
This module implements the basic convolutional blocks of Decagon. | |||||
Just as a quick reminder, the basic convolution formula here is: | |||||
y = A * (x * W) | |||||
where: | |||||
W is a weight matrix | |||||
A is an adjacency matrix | |||||
x is a matrix of latent representations of a particular type of neighbors. | |||||
As we have x here twice, a trick is obviously necessary for this to work. | |||||
A must be previously normalized with: | |||||
c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) | |||||
or | |||||
c_{r}^{i} = 1/|N_{r}^{i}| | |||||
Let's work through this step by step to convince ourselves that the | |||||
formula is correct. | |||||
x = [ | |||||
[0, 1, 0, 1], | |||||
[1, 1, 1, 0], | |||||
[0, 0, 0, 1] | |||||
] | |||||
W = [ | |||||
[0, 1], | |||||
[1, 0], | |||||
[0.5, 0.5], | |||||
[0.25, 0.75] | |||||
] | |||||
A = [ | |||||
[0, 1, 0], | |||||
[1, 0, 1], | |||||
[0, 1, 0] | |||||
] | |||||
so the graph looks like this: | |||||
(0) -- (1) -- (2) | |||||
and therefore the representations in the next layer should be: | |||||
h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k} | |||||
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} + | |||||
c_{r}^{1} * h_{1}^{k} | |||||
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k} | |||||
In actual Decagon code we can see that that latter part propagating directly | |||||
the old representation is gone. I will try to do the same for now. | |||||
So we have to only take care of: | |||||
h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W | |||||
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} | |||||
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W | |||||
If A is square the Decagon's EdgeMinibatchIterator preprocesses it as follows: | |||||
A = A + eye(len(A)) | |||||
rowsum = A.sum(1) | |||||
deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) | |||||
A = dot(A, deg_mat_inv_sqrt) | |||||
A = A.transpose() | |||||
A = A.dot(deg_mat_inv_sqrt) | |||||
Let's see what gives in our case: | |||||
A = A + eye(len(A)) | |||||
[ | |||||
[1, 1, 0], | |||||
[1, 1, 1], | |||||
[0, 1, 1] | |||||
] | |||||
rowsum = A.sum(1) | |||||
[2, 3, 2] | |||||
deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) | |||||
[ | |||||
[1./sqrt(2), 0, 0], | |||||
[0, 1./sqrt(3), 0], | |||||
[0, 0, 1./sqrt(2)] | |||||
] | |||||
A = dot(A, deg_mat_inv_sqrt) | |||||
[ | |||||
[ 1/sqrt(2), 1/sqrt(3), 0 ], | |||||
[ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ], | |||||
[ 0, 1/sqrt(3), 1/sqrt(2) ] | |||||
] | |||||
A = A.transpose() | |||||
[ | |||||
[ 1/sqrt(2), 1/sqrt(2), 0 ], | |||||
[ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ], | |||||
[ 0, 1/sqrt(2), 1/sqrt(2) ] | |||||
] | |||||
A = A.dot(deg_mat_inv_sqrt) | |||||
[ | |||||
[ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ], | |||||
[ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ], | |||||
[ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ], | |||||
] | |||||
thus: | |||||
[ | |||||
[0.5 , 0.40824829, 0. ], | |||||
[0.40824829, 0.33333333, 0.40824829], | |||||
[0. , 0.40824829, 0.5 ] | |||||
] | |||||
This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula. | |||||
Then, we get back to the main calculation: | |||||
y = x * W | |||||
y = A * y | |||||
y = x * W | |||||
[ | |||||
[ 1.25, 0.75 ], | |||||
[ 1.5 , 1.5 ], | |||||
[ 0.25, 0.75 ] | |||||
] | |||||
y = A * y | |||||
[ | |||||
0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ], | |||||
0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ], | |||||
0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ] | |||||
] | |||||
that is: | |||||
[ | |||||
[1.23737243, 0.98737244], | |||||
[1.11237243, 1.11237243], | |||||
[0.73737244, 0.98737244] | |||||
]. | |||||
All checks out nicely, good. | |||||
""" | |||||
import torch | |||||
from .dropout import dropout_sparse, \ | |||||
dropout | |||||
from .weights import init_glorot | |||||
from typing import List, Callable | |||||
class SparseGraphConv(torch.nn.Module): | |||||
"""Convolution layer for sparse inputs.""" | |||||
def __init__(self, in_channels: int, out_channels: int, | |||||
adjacency_matrix: torch.Tensor, **kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.in_channels = in_channels | |||||
self.out_channels = out_channels | |||||
self.weight = init_glorot(in_channels, out_channels) | |||||
self.adjacency_matrix = adjacency_matrix | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = torch.sparse.mm(x, self.weight) | |||||
x = torch.sparse.mm(self.adjacency_matrix, x) | |||||
return x | |||||
class SparseDropoutGraphConvActivation(torch.nn.Module): | |||||
def __init__(self, input_dim: int, output_dim: int, | |||||
adjacency_matrix: torch.Tensor, keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrix = adjacency_matrix | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix) | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = dropout_sparse(x, self.keep_prob) | |||||
x = self.sparse_graph_conv(x) | |||||
x = self.activation(x) | |||||
return x | |||||
class SparseMultiDGCA(torch.nn.Module): | |||||
def __init__(self, input_dim: List[int], output_dim: int, | |||||
adjacency_matrices: List[torch.Tensor], keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrices = adjacency_matrices | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.sparse_dgca = None | |||||
self.build() | |||||
def build(self): | |||||
if len(self.input_dim) != len(self.adjacency_matrices): | |||||
raise ValueError('input_dim must have the same length as adjacency_matrices') | |||||
self.sparse_dgca = [] | |||||
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): | |||||
self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) | |||||
def forward(self, x: List[torch.Tensor]) -> torch.Tensor: | |||||
if not isinstance(x, list): | |||||
raise ValueError('x must be a list of tensors') | |||||
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) | |||||
for i, f in enumerate(self.sparse_dgca): | |||||
out += f(x[i]) | |||||
out = torch.nn.functional.normalize(out, p=2, dim=1) | |||||
return out | |||||
class DenseGraphConv(torch.nn.Module): | |||||
def __init__(self, in_channels: int, out_channels: int, | |||||
adjacency_matrix: torch.Tensor, **kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.in_channels = in_channels | |||||
self.out_channels = out_channels | |||||
self.weight = init_glorot(in_channels, out_channels) | |||||
self.adjacency_matrix = adjacency_matrix | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = torch.mm(x, self.weight) | |||||
x = torch.mm(self.adjacency_matrix, x) | |||||
return x | |||||
class DenseDropoutGraphConvActivation(torch.nn.Module): | |||||
def __init__(self, input_dim: int, output_dim: int, | |||||
adjacency_matrix: torch.Tensor, keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix) | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = dropout(x, keep_prob=self.keep_prob) | |||||
x = self.graph_conv(x) | |||||
x = self.activation(x) | |||||
return x | |||||
class DenseMultiDGCA(torch.nn.Module): | |||||
def __init__(self, input_dim: List[int], output_dim: int, | |||||
adjacency_matrices: List[torch.Tensor], keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrices = adjacency_matrices | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.dgca = None | |||||
self.build() | |||||
def build(self): | |||||
if len(self.input_dim) != len(self.adjacency_matrices): | |||||
raise ValueError('input_dim must have the same length as adjacency_matrices') | |||||
self.dgca = [] | |||||
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): | |||||
self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) | |||||
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: | |||||
if not isinstance(x, list): | |||||
raise ValueError('x must be a list of tensors') | |||||
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) | |||||
for i, f in enumerate(self.dgca): | |||||
out += f(x[i]) | |||||
out = torch.nn.functional.normalize(out, p=2, dim=1) | |||||
return out | |||||
class GraphConv(torch.nn.Module): | |||||
"""Convolution layer for sparse AND dense inputs.""" | |||||
def __init__(self, in_channels: int, out_channels: int, | |||||
adjacency_matrix: torch.Tensor, **kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.in_channels = in_channels | |||||
self.out_channels = out_channels | |||||
self.weight = init_glorot(in_channels, out_channels) | |||||
self.adjacency_matrix = adjacency_matrix | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = torch.sparse.mm(x, self.weight) \ | |||||
if x.is_sparse \ | |||||
else torch.mm(x, self.weight) | |||||
x = torch.sparse.mm(self.adjacency_matrix, x) \ | |||||
if self.adjacency_matrix.is_sparse \ | |||||
else torch.mm(self.adjacency_matrix, x) | |||||
return x | |||||
class DropoutGraphConvActivation(torch.nn.Module): | |||||
def __init__(self, input_dim: int, output_dim: int, | |||||
adjacency_matrix: torch.Tensor, keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrix = adjacency_matrix | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = dropout_sparse(x, self.keep_prob) \ | |||||
if x.is_sparse \ | |||||
else dropout(x, self.keep_prob) | |||||
x = self.graph_conv(x) | |||||
x = self.activation(x) | |||||
return x |
@@ -0,0 +1,168 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
""" | |||||
This module implements the basic convolutional blocks of Decagon. | |||||
Just as a quick reminder, the basic convolution formula here is: | |||||
y = A * (x * W) | |||||
where: | |||||
W is a weight matrix | |||||
A is an adjacency matrix | |||||
x is a matrix of latent representations of a particular type of neighbors. | |||||
As we have x here twice, a trick is obviously necessary for this to work. | |||||
A must be previously normalized with: | |||||
c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) | |||||
or | |||||
c_{r}^{i} = 1/|N_{r}^{i}| | |||||
Let's work through this step by step to convince ourselves that the | |||||
formula is correct. | |||||
x = [ | |||||
[0, 1, 0, 1], | |||||
[1, 1, 1, 0], | |||||
[0, 0, 0, 1] | |||||
] | |||||
W = [ | |||||
[0, 1], | |||||
[1, 0], | |||||
[0.5, 0.5], | |||||
[0.25, 0.75] | |||||
] | |||||
A = [ | |||||
[0, 1, 0], | |||||
[1, 0, 1], | |||||
[0, 1, 0] | |||||
] | |||||
so the graph looks like this: | |||||
(0) -- (1) -- (2) | |||||
and therefore the representations in the next layer should be: | |||||
h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k} | |||||
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} + | |||||
c_{r}^{1} * h_{1}^{k} | |||||
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k} | |||||
In actual Decagon code we can see that that latter part propagating directly | |||||
the old representation is gone. I will try to do the same for now. | |||||
So we have to only take care of: | |||||
h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W | |||||
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} | |||||
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W | |||||
If A is square the Decagon's EdgeMinibatchIterator preprocesses it as follows: | |||||
A = A + eye(len(A)) | |||||
rowsum = A.sum(1) | |||||
deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) | |||||
A = dot(A, deg_mat_inv_sqrt) | |||||
A = A.transpose() | |||||
A = A.dot(deg_mat_inv_sqrt) | |||||
Let's see what gives in our case: | |||||
A = A + eye(len(A)) | |||||
[ | |||||
[1, 1, 0], | |||||
[1, 1, 1], | |||||
[0, 1, 1] | |||||
] | |||||
rowsum = A.sum(1) | |||||
[2, 3, 2] | |||||
deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) | |||||
[ | |||||
[1./sqrt(2), 0, 0], | |||||
[0, 1./sqrt(3), 0], | |||||
[0, 0, 1./sqrt(2)] | |||||
] | |||||
A = dot(A, deg_mat_inv_sqrt) | |||||
[ | |||||
[ 1/sqrt(2), 1/sqrt(3), 0 ], | |||||
[ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ], | |||||
[ 0, 1/sqrt(3), 1/sqrt(2) ] | |||||
] | |||||
A = A.transpose() | |||||
[ | |||||
[ 1/sqrt(2), 1/sqrt(2), 0 ], | |||||
[ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ], | |||||
[ 0, 1/sqrt(2), 1/sqrt(2) ] | |||||
] | |||||
A = A.dot(deg_mat_inv_sqrt) | |||||
[ | |||||
[ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ], | |||||
[ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ], | |||||
[ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ], | |||||
] | |||||
thus: | |||||
[ | |||||
[0.5 , 0.40824829, 0. ], | |||||
[0.40824829, 0.33333333, 0.40824829], | |||||
[0. , 0.40824829, 0.5 ] | |||||
] | |||||
This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula. | |||||
Then, we get back to the main calculation: | |||||
y = x * W | |||||
y = A * y | |||||
y = x * W | |||||
[ | |||||
[ 1.25, 0.75 ], | |||||
[ 1.5 , 1.5 ], | |||||
[ 0.25, 0.75 ] | |||||
] | |||||
y = A * y | |||||
[ | |||||
0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ], | |||||
0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ], | |||||
0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ] | |||||
] | |||||
that is: | |||||
[ | |||||
[1.23737243, 0.98737244], | |||||
[1.11237243, 1.11237243], | |||||
[0.73737244, 0.98737244] | |||||
]. | |||||
All checks out nicely, good. | |||||
""" | |||||
from .dense import * | |||||
from .sparse import * | |||||
from .universal import * |
@@ -0,0 +1,73 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import torch | |||||
from .dropout import dropout | |||||
from .weights import init_glorot | |||||
from typing import List, Callable | |||||
class DenseGraphConv(torch.nn.Module): | |||||
def __init__(self, in_channels: int, out_channels: int, | |||||
adjacency_matrix: torch.Tensor, **kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.in_channels = in_channels | |||||
self.out_channels = out_channels | |||||
self.weight = init_glorot(in_channels, out_channels) | |||||
self.adjacency_matrix = adjacency_matrix | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = torch.mm(x, self.weight) | |||||
x = torch.mm(self.adjacency_matrix, x) | |||||
return x | |||||
class DenseDropoutGraphConvActivation(torch.nn.Module): | |||||
def __init__(self, input_dim: int, output_dim: int, | |||||
adjacency_matrix: torch.Tensor, keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix) | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = dropout(x, keep_prob=self.keep_prob) | |||||
x = self.graph_conv(x) | |||||
x = self.activation(x) | |||||
return x | |||||
class DenseMultiDGCA(torch.nn.Module): | |||||
def __init__(self, input_dim: List[int], output_dim: int, | |||||
adjacency_matrices: List[torch.Tensor], keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrices = adjacency_matrices | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.dgca = None | |||||
self.build() | |||||
def build(self): | |||||
if len(self.input_dim) != len(self.adjacency_matrices): | |||||
raise ValueError('input_dim must have the same length as adjacency_matrices') | |||||
self.dgca = [] | |||||
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): | |||||
self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) | |||||
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: | |||||
if not isinstance(x, list): | |||||
raise ValueError('x must be a list of tensors') | |||||
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) | |||||
for i, f in enumerate(self.dgca): | |||||
out += f(x[i]) | |||||
out = torch.nn.functional.normalize(out, p=2, dim=1) | |||||
return out |
@@ -0,0 +1,78 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import torch | |||||
from .dropout import dropout_sparse | |||||
from .weights import init_glorot | |||||
from typing import List, Callable | |||||
class SparseGraphConv(torch.nn.Module): | |||||
"""Convolution layer for sparse inputs.""" | |||||
def __init__(self, in_channels: int, out_channels: int, | |||||
adjacency_matrix: torch.Tensor, **kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.in_channels = in_channels | |||||
self.out_channels = out_channels | |||||
self.weight = init_glorot(in_channels, out_channels) | |||||
self.adjacency_matrix = adjacency_matrix | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = torch.sparse.mm(x, self.weight) | |||||
x = torch.sparse.mm(self.adjacency_matrix, x) | |||||
return x | |||||
class SparseDropoutGraphConvActivation(torch.nn.Module): | |||||
def __init__(self, input_dim: int, output_dim: int, | |||||
adjacency_matrix: torch.Tensor, keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrix = adjacency_matrix | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix) | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = dropout_sparse(x, self.keep_prob) | |||||
x = self.sparse_graph_conv(x) | |||||
x = self.activation(x) | |||||
return x | |||||
class SparseMultiDGCA(torch.nn.Module): | |||||
def __init__(self, input_dim: List[int], output_dim: int, | |||||
adjacency_matrices: List[torch.Tensor], keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrices = adjacency_matrices | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.sparse_dgca = None | |||||
self.build() | |||||
def build(self): | |||||
if len(self.input_dim) != len(self.adjacency_matrices): | |||||
raise ValueError('input_dim must have the same length as adjacency_matrices') | |||||
self.sparse_dgca = [] | |||||
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): | |||||
self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) | |||||
def forward(self, x: List[torch.Tensor]) -> torch.Tensor: | |||||
if not isinstance(x, list): | |||||
raise ValueError('x must be a list of tensors') | |||||
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) | |||||
for i, f in enumerate(self.sparse_dgca): | |||||
out += f(x[i]) | |||||
out = torch.nn.functional.normalize(out, p=2, dim=1) | |||||
return out |
@@ -0,0 +1,85 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import torch | |||||
from .dropout import dropout_sparse, \ | |||||
dropout | |||||
from .weights import init_glorot | |||||
from typing import List, Callable | |||||
class GraphConv(torch.nn.Module): | |||||
"""Convolution layer for sparse AND dense inputs.""" | |||||
def __init__(self, in_channels: int, out_channels: int, | |||||
adjacency_matrix: torch.Tensor, **kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.in_channels = in_channels | |||||
self.out_channels = out_channels | |||||
self.weight = init_glorot(in_channels, out_channels) | |||||
self.adjacency_matrix = adjacency_matrix | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = torch.sparse.mm(x, self.weight) \ | |||||
if x.is_sparse \ | |||||
else torch.mm(x, self.weight) | |||||
x = torch.sparse.mm(self.adjacency_matrix, x) \ | |||||
if self.adjacency_matrix.is_sparse \ | |||||
else torch.mm(self.adjacency_matrix, x) | |||||
return x | |||||
class DropoutGraphConvActivation(torch.nn.Module): | |||||
def __init__(self, input_dim: int, output_dim: int, | |||||
adjacency_matrix: torch.Tensor, keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrix = adjacency_matrix | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = dropout_sparse(x, self.keep_prob) \ | |||||
if x.is_sparse \ | |||||
else dropout(x, self.keep_prob) | |||||
x = self.graph_conv(x) | |||||
x = self.activation(x) | |||||
return x | |||||
class MultiDGCA(torch.nn.Module): | |||||
def __init__(self, input_dim: List[int], output_dim: int, | |||||
adjacency_matrices: List[torch.Tensor], keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrices = adjacency_matrices | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.dgca = None | |||||
self.build() | |||||
def build(self): | |||||
if len(self.input_dim) != len(self.adjacency_matrices): | |||||
raise ValueError('input_dim must have the same length as adjacency_matrices') | |||||
self.dgca = [] | |||||
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): | |||||
self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) | |||||
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: | |||||
if not isinstance(x, list): | |||||
raise ValueError('x must be a list of tensors') | |||||
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) | |||||
for i, f in enumerate(self.dgca): | |||||
out += f(x[i]) | |||||
out = torch.nn.functional.normalize(out, p=2, dim=1) | |||||
return out |