From c74555fef51aec638d50b198d002cbd1bdc84e01 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 29 May 2020 10:46:57 +0200 Subject: [PATCH] Refactor convolve into 3 separate modules. --- src/decagon_pytorch/convolve.py | 346 ---------------------- src/decagon_pytorch/convolve/__init__.py | 168 +++++++++++ src/decagon_pytorch/convolve/dense.py | 73 +++++ src/decagon_pytorch/convolve/sparse.py | 78 +++++ src/decagon_pytorch/convolve/universal.py | 85 ++++++ 5 files changed, 404 insertions(+), 346 deletions(-) delete mode 100644 src/decagon_pytorch/convolve.py create mode 100644 src/decagon_pytorch/convolve/__init__.py create mode 100644 src/decagon_pytorch/convolve/dense.py create mode 100644 src/decagon_pytorch/convolve/sparse.py create mode 100644 src/decagon_pytorch/convolve/universal.py diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py deleted file mode 100644 index 4048521..0000000 --- a/src/decagon_pytorch/convolve.py +++ /dev/null @@ -1,346 +0,0 @@ -# -# Copyright (C) Stanislaw Adaszewski, 2020 -# License: GPLv3 -# - -""" -This module implements the basic convolutional blocks of Decagon. -Just as a quick reminder, the basic convolution formula here is: - -y = A * (x * W) - -where: - -W is a weight matrix -A is an adjacency matrix -x is a matrix of latent representations of a particular type of neighbors. - -As we have x here twice, a trick is obviously necessary for this to work. -A must be previously normalized with: - -c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) - -or - -c_{r}^{i} = 1/|N_{r}^{i}| - -Let's work through this step by step to convince ourselves that the -formula is correct. - -x = [ - [0, 1, 0, 1], - [1, 1, 1, 0], - [0, 0, 0, 1] -] - -W = [ - [0, 1], - [1, 0], - [0.5, 0.5], - [0.25, 0.75] -] - -A = [ - [0, 1, 0], - [1, 0, 1], - [0, 1, 0] -] - -so the graph looks like this: - -(0) -- (1) -- (2) - -and therefore the representations in the next layer should be: - -h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k} -h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} + - c_{r}^{1} * h_{1}^{k} -h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k} - -In actual Decagon code we can see that that latter part propagating directly -the old representation is gone. I will try to do the same for now. - -So we have to only take care of: - -h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W -h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} -h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W - -If A is square the Decagon's EdgeMinibatchIterator preprocesses it as follows: - -A = A + eye(len(A)) -rowsum = A.sum(1) -deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) -A = dot(A, deg_mat_inv_sqrt) -A = A.transpose() -A = A.dot(deg_mat_inv_sqrt) - -Let's see what gives in our case: - -A = A + eye(len(A)) - -[ - [1, 1, 0], - [1, 1, 1], - [0, 1, 1] -] - -rowsum = A.sum(1) - -[2, 3, 2] - -deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) - -[ - [1./sqrt(2), 0, 0], - [0, 1./sqrt(3), 0], - [0, 0, 1./sqrt(2)] -] - -A = dot(A, deg_mat_inv_sqrt) - -[ - [ 1/sqrt(2), 1/sqrt(3), 0 ], - [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ], - [ 0, 1/sqrt(3), 1/sqrt(2) ] -] - -A = A.transpose() - -[ - [ 1/sqrt(2), 1/sqrt(2), 0 ], - [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ], - [ 0, 1/sqrt(2), 1/sqrt(2) ] -] - -A = A.dot(deg_mat_inv_sqrt) - -[ - [ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ], - [ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ], - [ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ], -] - -thus: - -[ - [0.5 , 0.40824829, 0. ], - [0.40824829, 0.33333333, 0.40824829], - [0. , 0.40824829, 0.5 ] -] - -This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula. - -Then, we get back to the main calculation: - -y = x * W -y = A * y - -y = x * W - -[ - [ 1.25, 0.75 ], - [ 1.5 , 1.5 ], - [ 0.25, 0.75 ] -] - -y = A * y - -[ - 0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ], - 0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ], - 0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ] -] - -that is: - -[ - [1.23737243, 0.98737244], - [1.11237243, 1.11237243], - [0.73737244, 0.98737244] -]. - -All checks out nicely, good. - -""" - - -import torch -from .dropout import dropout_sparse, \ - dropout -from .weights import init_glorot -from typing import List, Callable - - -class SparseGraphConv(torch.nn.Module): - """Convolution layer for sparse inputs.""" - def __init__(self, in_channels: int, out_channels: int, - adjacency_matrix: torch.Tensor, **kwargs) -> None: - super().__init__(**kwargs) - self.in_channels = in_channels - self.out_channels = out_channels - self.weight = init_glorot(in_channels, out_channels) - self.adjacency_matrix = adjacency_matrix - - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = torch.sparse.mm(x, self.weight) - x = torch.sparse.mm(self.adjacency_matrix, x) - return x - - -class SparseDropoutGraphConvActivation(torch.nn.Module): - def __init__(self, input_dim: int, output_dim: int, - adjacency_matrix: torch.Tensor, keep_prob: float=1., - activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, - **kwargs) -> None: - super().__init__(**kwargs) - self.input_dim = input_dim - self.output_dim = output_dim - self.adjacency_matrix = adjacency_matrix - self.keep_prob = keep_prob - self.activation = activation - self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = dropout_sparse(x, self.keep_prob) - x = self.sparse_graph_conv(x) - x = self.activation(x) - return x - - -class SparseMultiDGCA(torch.nn.Module): - def __init__(self, input_dim: List[int], output_dim: int, - adjacency_matrices: List[torch.Tensor], keep_prob: float=1., - activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, - **kwargs) -> None: - super().__init__(**kwargs) - self.input_dim = input_dim - self.output_dim = output_dim - self.adjacency_matrices = adjacency_matrices - self.keep_prob = keep_prob - self.activation = activation - self.sparse_dgca = None - self.build() - - def build(self): - if len(self.input_dim) != len(self.adjacency_matrices): - raise ValueError('input_dim must have the same length as adjacency_matrices') - self.sparse_dgca = [] - for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): - self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) - - def forward(self, x: List[torch.Tensor]) -> torch.Tensor: - if not isinstance(x, list): - raise ValueError('x must be a list of tensors') - out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) - for i, f in enumerate(self.sparse_dgca): - out += f(x[i]) - out = torch.nn.functional.normalize(out, p=2, dim=1) - return out - - -class DenseGraphConv(torch.nn.Module): - def __init__(self, in_channels: int, out_channels: int, - adjacency_matrix: torch.Tensor, **kwargs) -> None: - super().__init__(**kwargs) - self.in_channels = in_channels - self.out_channels = out_channels - self.weight = init_glorot(in_channels, out_channels) - self.adjacency_matrix = adjacency_matrix - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = torch.mm(x, self.weight) - x = torch.mm(self.adjacency_matrix, x) - return x - - -class DenseDropoutGraphConvActivation(torch.nn.Module): - def __init__(self, input_dim: int, output_dim: int, - adjacency_matrix: torch.Tensor, keep_prob: float=1., - activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, - **kwargs) -> None: - super().__init__(**kwargs) - self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix) - self.keep_prob = keep_prob - self.activation = activation - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = dropout(x, keep_prob=self.keep_prob) - x = self.graph_conv(x) - x = self.activation(x) - return x - - -class DenseMultiDGCA(torch.nn.Module): - def __init__(self, input_dim: List[int], output_dim: int, - adjacency_matrices: List[torch.Tensor], keep_prob: float=1., - activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, - **kwargs) -> None: - super().__init__(**kwargs) - self.input_dim = input_dim - self.output_dim = output_dim - self.adjacency_matrices = adjacency_matrices - self.keep_prob = keep_prob - self.activation = activation - self.dgca = None - self.build() - - def build(self): - if len(self.input_dim) != len(self.adjacency_matrices): - raise ValueError('input_dim must have the same length as adjacency_matrices') - self.dgca = [] - for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): - self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) - - def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: - if not isinstance(x, list): - raise ValueError('x must be a list of tensors') - out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) - for i, f in enumerate(self.dgca): - out += f(x[i]) - out = torch.nn.functional.normalize(out, p=2, dim=1) - return out - - -class GraphConv(torch.nn.Module): - """Convolution layer for sparse AND dense inputs.""" - def __init__(self, in_channels: int, out_channels: int, - adjacency_matrix: torch.Tensor, **kwargs) -> None: - super().__init__(**kwargs) - self.in_channels = in_channels - self.out_channels = out_channels - self.weight = init_glorot(in_channels, out_channels) - self.adjacency_matrix = adjacency_matrix - - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = torch.sparse.mm(x, self.weight) \ - if x.is_sparse \ - else torch.mm(x, self.weight) - x = torch.sparse.mm(self.adjacency_matrix, x) \ - if self.adjacency_matrix.is_sparse \ - else torch.mm(self.adjacency_matrix, x) - return x - - -class DropoutGraphConvActivation(torch.nn.Module): - def __init__(self, input_dim: int, output_dim: int, - adjacency_matrix: torch.Tensor, keep_prob: float=1., - activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, - **kwargs) -> None: - super().__init__(**kwargs) - self.input_dim = input_dim - self.output_dim = output_dim - self.adjacency_matrix = adjacency_matrix - self.keep_prob = keep_prob - self.activation = activation - self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = dropout_sparse(x, self.keep_prob) \ - if x.is_sparse \ - else dropout(x, self.keep_prob) - x = self.graph_conv(x) - x = self.activation(x) - return x diff --git a/src/decagon_pytorch/convolve/__init__.py b/src/decagon_pytorch/convolve/__init__.py new file mode 100644 index 0000000..f3c2e40 --- /dev/null +++ b/src/decagon_pytorch/convolve/__init__.py @@ -0,0 +1,168 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + +""" +This module implements the basic convolutional blocks of Decagon. +Just as a quick reminder, the basic convolution formula here is: + +y = A * (x * W) + +where: + +W is a weight matrix +A is an adjacency matrix +x is a matrix of latent representations of a particular type of neighbors. + +As we have x here twice, a trick is obviously necessary for this to work. +A must be previously normalized with: + +c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) + +or + +c_{r}^{i} = 1/|N_{r}^{i}| + +Let's work through this step by step to convince ourselves that the +formula is correct. + +x = [ + [0, 1, 0, 1], + [1, 1, 1, 0], + [0, 0, 0, 1] +] + +W = [ + [0, 1], + [1, 0], + [0.5, 0.5], + [0.25, 0.75] +] + +A = [ + [0, 1, 0], + [1, 0, 1], + [0, 1, 0] +] + +so the graph looks like this: + +(0) -- (1) -- (2) + +and therefore the representations in the next layer should be: + +h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k} +h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} + + c_{r}^{1} * h_{1}^{k} +h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k} + +In actual Decagon code we can see that that latter part propagating directly +the old representation is gone. I will try to do the same for now. + +So we have to only take care of: + +h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W +h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} +h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + +If A is square the Decagon's EdgeMinibatchIterator preprocesses it as follows: + +A = A + eye(len(A)) +rowsum = A.sum(1) +deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) +A = dot(A, deg_mat_inv_sqrt) +A = A.transpose() +A = A.dot(deg_mat_inv_sqrt) + +Let's see what gives in our case: + +A = A + eye(len(A)) + +[ + [1, 1, 0], + [1, 1, 1], + [0, 1, 1] +] + +rowsum = A.sum(1) + +[2, 3, 2] + +deg_mat_inv_sqrt = diags(power(rowsum, -0.5)) + +[ + [1./sqrt(2), 0, 0], + [0, 1./sqrt(3), 0], + [0, 0, 1./sqrt(2)] +] + +A = dot(A, deg_mat_inv_sqrt) + +[ + [ 1/sqrt(2), 1/sqrt(3), 0 ], + [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ], + [ 0, 1/sqrt(3), 1/sqrt(2) ] +] + +A = A.transpose() + +[ + [ 1/sqrt(2), 1/sqrt(2), 0 ], + [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ], + [ 0, 1/sqrt(2), 1/sqrt(2) ] +] + +A = A.dot(deg_mat_inv_sqrt) + +[ + [ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ], + [ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ], + [ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ], +] + +thus: + +[ + [0.5 , 0.40824829, 0. ], + [0.40824829, 0.33333333, 0.40824829], + [0. , 0.40824829, 0.5 ] +] + +This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula. + +Then, we get back to the main calculation: + +y = x * W +y = A * y + +y = x * W + +[ + [ 1.25, 0.75 ], + [ 1.5 , 1.5 ], + [ 0.25, 0.75 ] +] + +y = A * y + +[ + 0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ], + 0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ], + 0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ] +] + +that is: + +[ + [1.23737243, 0.98737244], + [1.11237243, 1.11237243], + [0.73737244, 0.98737244] +]. + +All checks out nicely, good. +""" + +from .dense import * +from .sparse import * +from .universal import * diff --git a/src/decagon_pytorch/convolve/dense.py b/src/decagon_pytorch/convolve/dense.py new file mode 100644 index 0000000..11e2b0d --- /dev/null +++ b/src/decagon_pytorch/convolve/dense.py @@ -0,0 +1,73 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import torch +from .dropout import dropout +from .weights import init_glorot +from typing import List, Callable + + +class DenseGraphConv(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, + adjacency_matrix: torch.Tensor, **kwargs) -> None: + super().__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.weight = init_glorot(in_channels, out_channels) + self.adjacency_matrix = adjacency_matrix + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.mm(x, self.weight) + x = torch.mm(self.adjacency_matrix, x) + return x + + +class DenseDropoutGraphConvActivation(torch.nn.Module): + def __init__(self, input_dim: int, output_dim: int, + adjacency_matrix: torch.Tensor, keep_prob: float=1., + activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, + **kwargs) -> None: + super().__init__(**kwargs) + self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix) + self.keep_prob = keep_prob + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = dropout(x, keep_prob=self.keep_prob) + x = self.graph_conv(x) + x = self.activation(x) + return x + + +class DenseMultiDGCA(torch.nn.Module): + def __init__(self, input_dim: List[int], output_dim: int, + adjacency_matrices: List[torch.Tensor], keep_prob: float=1., + activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, + **kwargs) -> None: + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim + self.adjacency_matrices = adjacency_matrices + self.keep_prob = keep_prob + self.activation = activation + self.dgca = None + self.build() + + def build(self): + if len(self.input_dim) != len(self.adjacency_matrices): + raise ValueError('input_dim must have the same length as adjacency_matrices') + self.dgca = [] + for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): + self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + if not isinstance(x, list): + raise ValueError('x must be a list of tensors') + out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) + for i, f in enumerate(self.dgca): + out += f(x[i]) + out = torch.nn.functional.normalize(out, p=2, dim=1) + return out diff --git a/src/decagon_pytorch/convolve/sparse.py b/src/decagon_pytorch/convolve/sparse.py new file mode 100644 index 0000000..7cbabf2 --- /dev/null +++ b/src/decagon_pytorch/convolve/sparse.py @@ -0,0 +1,78 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import torch +from .dropout import dropout_sparse +from .weights import init_glorot +from typing import List, Callable + + +class SparseGraphConv(torch.nn.Module): + """Convolution layer for sparse inputs.""" + def __init__(self, in_channels: int, out_channels: int, + adjacency_matrix: torch.Tensor, **kwargs) -> None: + super().__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.weight = init_glorot(in_channels, out_channels) + self.adjacency_matrix = adjacency_matrix + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.sparse.mm(x, self.weight) + x = torch.sparse.mm(self.adjacency_matrix, x) + return x + + +class SparseDropoutGraphConvActivation(torch.nn.Module): + def __init__(self, input_dim: int, output_dim: int, + adjacency_matrix: torch.Tensor, keep_prob: float=1., + activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, + **kwargs) -> None: + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim + self.adjacency_matrix = adjacency_matrix + self.keep_prob = keep_prob + self.activation = activation + self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = dropout_sparse(x, self.keep_prob) + x = self.sparse_graph_conv(x) + x = self.activation(x) + return x + + +class SparseMultiDGCA(torch.nn.Module): + def __init__(self, input_dim: List[int], output_dim: int, + adjacency_matrices: List[torch.Tensor], keep_prob: float=1., + activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, + **kwargs) -> None: + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim + self.adjacency_matrices = adjacency_matrices + self.keep_prob = keep_prob + self.activation = activation + self.sparse_dgca = None + self.build() + + def build(self): + if len(self.input_dim) != len(self.adjacency_matrices): + raise ValueError('input_dim must have the same length as adjacency_matrices') + self.sparse_dgca = [] + for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): + self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) + + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + if not isinstance(x, list): + raise ValueError('x must be a list of tensors') + out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) + for i, f in enumerate(self.sparse_dgca): + out += f(x[i]) + out = torch.nn.functional.normalize(out, p=2, dim=1) + return out diff --git a/src/decagon_pytorch/convolve/universal.py b/src/decagon_pytorch/convolve/universal.py new file mode 100644 index 0000000..0bcc8c3 --- /dev/null +++ b/src/decagon_pytorch/convolve/universal.py @@ -0,0 +1,85 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import torch +from .dropout import dropout_sparse, \ + dropout +from .weights import init_glorot +from typing import List, Callable + + +class GraphConv(torch.nn.Module): + """Convolution layer for sparse AND dense inputs.""" + def __init__(self, in_channels: int, out_channels: int, + adjacency_matrix: torch.Tensor, **kwargs) -> None: + super().__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.weight = init_glorot(in_channels, out_channels) + self.adjacency_matrix = adjacency_matrix + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.sparse.mm(x, self.weight) \ + if x.is_sparse \ + else torch.mm(x, self.weight) + x = torch.sparse.mm(self.adjacency_matrix, x) \ + if self.adjacency_matrix.is_sparse \ + else torch.mm(self.adjacency_matrix, x) + return x + + +class DropoutGraphConvActivation(torch.nn.Module): + def __init__(self, input_dim: int, output_dim: int, + adjacency_matrix: torch.Tensor, keep_prob: float=1., + activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, + **kwargs) -> None: + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim + self.adjacency_matrix = adjacency_matrix + self.keep_prob = keep_prob + self.activation = activation + self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = dropout_sparse(x, self.keep_prob) \ + if x.is_sparse \ + else dropout(x, self.keep_prob) + x = self.graph_conv(x) + x = self.activation(x) + return x + + +class MultiDGCA(torch.nn.Module): + def __init__(self, input_dim: List[int], output_dim: int, + adjacency_matrices: List[torch.Tensor], keep_prob: float=1., + activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, + **kwargs) -> None: + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim + self.adjacency_matrices = adjacency_matrices + self.keep_prob = keep_prob + self.activation = activation + self.dgca = None + self.build() + + def build(self): + if len(self.input_dim) != len(self.adjacency_matrices): + raise ValueError('input_dim must have the same length as adjacency_matrices') + self.dgca = [] + for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): + self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + if not isinstance(x, list): + raise ValueError('x must be a list of tensors') + out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) + for i, f in enumerate(self.dgca): + out += f(x[i]) + out = torch.nn.functional.normalize(out, p=2, dim=1) + return out