diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py
index 3b116e7..d50f4b6 100644
--- a/src/decagon_pytorch/convolve.py
+++ b/src/decagon_pytorch/convolve.py
@@ -1,104 +1,272 @@
-import torch
-from .dropout import dropout_sparse, \
-    dropout
-from .weights import init_glorot
-
-
-class SparseGraphConv(torch.nn.Module):
-    """Convolution layer for sparse inputs."""
-    def __init__(self, in_channels, out_channels,
-            adjacency_matrix, **kwargs):
-        super().__init__(**kwargs)
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.weight = init_glorot(in_channels, out_channels)
-        self.adjacency_matrix = adjacency_matrix
-
-
-    def forward(self, x):
-        x = torch.sparse.mm(x, self.weight)
-        x = torch.sparse.mm(self.adjacency_matrix, x)
-        return x
-
-
-class SparseDropoutGraphConvActivation(torch.nn.Module):
-    def __init__(self, input_dim, output_dim,
-            adjacency_matrix, keep_prob=1.,
-            activation=torch.nn.functional.relu,
-            **kwargs):
-        super().__init__(**kwargs)
-        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
-        self.keep_prob = keep_prob
-        self.activation = activation
-
-    def forward(self, x):
-        x = dropout_sparse(x, self.keep_prob)
-        x = self.sparse_graph_conv(x)
-        x = self.activation(x)
-        return x
-
-
-class SparseMultiDGCA(torch.nn.Module):
-    def __init__(self, input_dim, output_dim,
-            adjacency_matrices, keep_prob=1.,
-            activation=torch.nn.functional.relu,
-            **kwargs):
-        super().__init__(**kwargs)
-        self.output_dim = output_dim
-        self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
-
-    def forward(self, x):
-        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
-        for f in self.sparse_dgca:
-            out += f(x)
-        out = torch.nn.functional.normalize(out, p=2, dim=1)
-        return out
-
-
-class GraphConv(torch.nn.Module):
-    def __init__(self, in_channels, out_channels,
-            adjacency_matrix, **kwargs):
-        super().__init__(**kwargs)
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.weight = init_glorot(in_channels, out_channels)
-        self.adjacency_matrix = adjacency_matrix
-
-    def forward(self, x):
-        x = torch.mm(x, self.weight)
-        x = torch.mm(self.adjacency_matrix, x)
-        return x
-
-
-class DropoutGraphConvActivation(torch.nn.Module):
-    def __init__(self, input_dim, output_dim,
-            adjacency_matrix, keep_prob=1.,
-            activation=torch.nn.functional.relu,
-            **kwargs):
-        super().__init__(**kwargs)
-        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
-        self.keep_prob = keep_prob
-        self.activation = activation
-
-    def forward(self, x):
-        x = dropout(x, keep_prob=self.keep_prob)
-        x = self.graph_conv(x)
-        x = self.activation(x)
-        return x
-
-
-class MultiDGCA(torch.nn.Module):
-    def __init__(self, input_dim, output_dim,
-            adjacency_matrices, keep_prob=1.,
-            activation=torch.nn.functional.relu,
-            **kwargs):
-        super().__init__(**kwargs)
-        self.output_dim = output_dim
-        self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
-
-    def forward(self, x):
-        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
-        for f in self.dgca:
-            out += f(x)
-        out = torch.nn.functional.normalize(out, p=2, dim=1)
-        return out
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+"""
+This module implements the basic convolutional blocks of Decagon.
+Just as a quick reminder, the basic convolution formula here is:
+
+y = A * (x * W)
+
+where:
+
+W is a weight matrix
+A is an adjacency matrix
+x is a matrix of latent representations of neighbors of a particular type.
+
+Since the same x stands in for both a node and its neighbors, a trick is
+necessary for this to work: A must first be normalized with
+
+c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
+
+or
+
+c_{r}^{i} = 1/|N_{r}^{i}|
+
+Let's work through this step by step to convince ourselves that the
+formula is correct.
+
+x = [
+    [0, 1, 0, 1],
+    [1, 1, 1, 0],
+    [0, 0, 0, 1]
+]
+
+W = [
+    [0, 1],
+    [1, 0],
+    [0.5, 0.5],
+    [0.25, 0.75]
+]
+
+A = [
+    [0, 1, 0],
+    [1, 0, 1],
+    [0, 1, 0]
+]
+
+so the graph looks like this:
+
+(0) -- (1) -- (2)
+
+and therefore the representations in the next layer should be:
+
+h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
+h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W +
+    c_{r}^{1} * h_{1}^{k}
+h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}
+
+In the actual Decagon code we can see that the latter term, which propagates
+the old representation directly, is gone. I will do the same here for now.
+
+So we only have to take care of:
+
+h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
+h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W
+h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W
+
+If A is square, Decagon's EdgeMinibatchIterator preprocesses it as follows:
+
+A = A + eye(len(A))
+rowsum = A.sum(1)
+deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
+A = dot(A, deg_mat_inv_sqrt)
+A = A.transpose()
+A = A.dot(deg_mat_inv_sqrt)
+
+Let's see what this gives in our case:
+
+A = A + eye(len(A))
+
+[
+    [1, 1, 0],
+    [1, 1, 1],
+    [0, 1, 1]
+]
+
+rowsum = A.sum(1)
+
+[2, 3, 2]
+
+deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
+
+[
+    [1./sqrt(2), 0, 0],
+    [0, 1./sqrt(3), 0],
+    [0, 0, 1./sqrt(2)]
+]
+
+A = dot(A, deg_mat_inv_sqrt)
+
+[
+    [ 1/sqrt(2), 1/sqrt(3), 0 ],
+    [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
+    [ 0, 1/sqrt(3), 1/sqrt(2) ]
+]
+
+A = A.transpose()
+
+[
+    [ 1/sqrt(2), 1/sqrt(2), 0 ],
+    [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
+    [ 0, 1/sqrt(2), 1/sqrt(2) ]
+]
+
+A = A.dot(deg_mat_inv_sqrt)
+
+[
+    [ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ],
+    [ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ],
+    [ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ]
+]
+
+thus:
+
+[
+    [0.5       , 0.40824829, 0.        ],
+    [0.40824829, 0.33333333, 0.40824829],
+    [0.        , 0.40824829, 0.5       ]
+]
+
+This checks out with the c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula.
+
+Then, we get back to the main calculation:
+
+y = x * W
+y = A * y
+
+y = x * W
+
+[
+    [ 1.25, 0.75 ],
+    [ 1.5 , 1.5  ],
+    [ 0.25, 0.75 ]
+]
+
+y = A * y
+
+[
+    0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
+    0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ],
+    0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
+]
+
+that is:
+
+[
+    [1.23737243, 0.98737244],
+    [1.11237243, 1.11237243],
+    [0.73737244, 0.98737244]
+].
+
+Everything checks out nicely.
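+
+As a quick numerical sanity check, the walkthrough above can be reproduced
+with NumPy (used here purely for illustration; the module itself operates on
+PyTorch tensors):
+
+import numpy as np
+
+x = np.array([[0, 1, 0, 1],
+              [1, 1, 1, 0],
+              [0, 0, 0, 1]], dtype=np.float64)
+W = np.array([[0, 1],
+              [1, 0],
+              [0.5, 0.5],
+              [0.25, 0.75]])
+A = np.array([[0, 1, 0],
+              [1, 0, 1],
+              [0, 1, 0]], dtype=np.float64)
+
+A = A + np.eye(len(A))                                # add self-connections
+deg_mat_inv_sqrt = np.diag(np.power(A.sum(1), -0.5))  # D^{-1/2}
+A = A.dot(deg_mat_inv_sqrt)
+A = A.transpose()
+A = A.dot(deg_mat_inv_sqrt)                           # symmetrically normalized A
+
+y = A.dot(x.dot(W))
+# y is approximately:
+# [[1.2374, 0.9874],
+#  [1.1124, 1.1124],
+#  [0.7374, 0.9874]]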
+
+"""
+
+
+import torch
+from .dropout import dropout_sparse, \
+    dropout
+from .weights import init_glorot
+from typing import List, Callable
+
+
+class SparseGraphConv(torch.nn.Module):
+    """Convolution layer for sparse inputs."""
+    def __init__(self, in_channels: int, out_channels: int,
+            adjacency_matrix: torch.Tensor, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.weight = init_glorot(in_channels, out_channels)
+        self.adjacency_matrix = adjacency_matrix
+
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = torch.sparse.mm(x, self.weight)
+        x = torch.sparse.mm(self.adjacency_matrix, x)
+        return x
+
+
+class SparseDropoutGraphConvActivation(torch.nn.Module):
+    """Sparse dropout followed by sparse graph convolution and activation."""
+    def __init__(self, input_dim: int, output_dim: int,
+            adjacency_matrix: torch.Tensor, keep_prob: float=1.,
+            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+            **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
+        self.keep_prob = keep_prob
+        self.activation = activation
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = dropout_sparse(x, self.keep_prob)
+        x = self.sparse_graph_conv(x)
+        x = self.activation(x)
+        return x
+
+
+class SparseMultiDGCA(torch.nn.Module):
+    """Sum of sparse dropout-graph-conv-activation blocks over multiple
+    adjacency matrices, followed by L2 normalization."""
+    def __init__(self, input_dim: int, output_dim: int,
+            adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
+            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+            **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.output_dim = output_dim
+        self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
+        for f in self.sparse_dgca:
+            out += f(x)
+        out = torch.nn.functional.normalize(out, p=2, dim=1)
+        return out
+
+
+class GraphConv(torch.nn.Module):
+    """Convolution layer for dense inputs."""
+    def __init__(self, in_channels: int, out_channels: int,
+            adjacency_matrix: torch.Tensor, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.weight = init_glorot(in_channels, out_channels)
+        self.adjacency_matrix = adjacency_matrix
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = torch.mm(x, self.weight)
+        x = torch.mm(self.adjacency_matrix, x)
+        return x
+
+
+class DropoutGraphConvActivation(torch.nn.Module):
+    """Dropout followed by dense graph convolution and activation."""
+    def __init__(self, input_dim: int, output_dim: int,
+            adjacency_matrix: torch.Tensor, keep_prob: float=1.,
+            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+            **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
+        self.keep_prob = keep_prob
+        self.activation = activation
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = dropout(x, keep_prob=self.keep_prob)
+        x = self.graph_conv(x)
+        x = self.activation(x)
+        return x
+
+
+class MultiDGCA(torch.nn.Module):
+    """Dense counterpart of SparseMultiDGCA."""
+    def __init__(self, input_dim: int, output_dim: int,
+            adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
+            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+            **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.output_dim = output_dim
+        self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
+        for f in self.dgca:
+            out += f(x)
+        out = torch.nn.functional.normalize(out, p=2, dim=1)
+        return out
diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py
index 97c7c2a..a75c005 100644
--- a/src/decagon_pytorch/layer.py
+++ b/src/decagon_pytorch/layer.py
@@ -21,7 +21,7 @@
 
 import torch
 
-from .convole import SparseMultiDGCA
+from .convolve import SparseMultiDGCA
 
 
 class InputLayer(torch.nn.Module):