#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

"""
This module implements the basic convolutional blocks of Decagon.

Just as a quick reminder, the basic convolution formula here is:

    y = A * (x * W)

where:

    W is a weight matrix,
    A is an adjacency matrix,
    x is a matrix of latent representations of a particular type
        of neighbors.

As we have x here twice, a trick is obviously necessary for this
to work: A must be normalized beforehand with:

    c_{r}^{ij} = 1 / sqrt(|N_{r}^{i}| * |N_{r}^{j}|)

or:

    c_{r}^{i} = 1 / |N_{r}^{i}|

Let's work through this step by step to convince ourselves that the
formula is correct.

    x = [
        [0, 1, 0, 1],
        [1, 1, 1, 0],
        [0, 0, 0, 1]
    ]

    W = [
        [0, 1],
        [1, 0],
        [0.5, 0.5],
        [0.25, 0.75]
    ]

    A = [
        [0, 1, 0],
        [1, 0, 1],
        [0, 1, 0]
    ]

so the graph looks like this:

    (0) -- (1) -- (2)

and therefore the representations in the next layer should be:

    h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
    h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W
        + c_{r}^{1} * h_{1}^{k}
    h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}

In the actual Decagon code we can see that the latter part, which
directly propagates the old representation, is gone. I will try to do
the same for now. So we only have to take care of:

    h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
    h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W
    h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W

If A is square, Decagon's EdgeMinibatchIterator preprocesses it as
follows:

    A = A + eye(len(A))
    rowsum = A.sum(1)
    deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
    A = dot(A, deg_mat_inv_sqrt)
    A = A.transpose()
    A = A.dot(deg_mat_inv_sqrt)

Let's see what this gives in our case:

    A = A + eye(len(A))

        [
            [1, 1, 0],
            [1, 1, 1],
            [0, 1, 1]
        ]

    rowsum = A.sum(1)

        [2, 3, 2]

    deg_mat_inv_sqrt = diags(power(rowsum, -0.5))

        [
            [1/sqrt(2), 0,         0        ],
            [0,         1/sqrt(3), 0        ],
            [0,         0,         1/sqrt(2)]
        ]

    A = dot(A, deg_mat_inv_sqrt)

        [
            [ 1/sqrt(2), 1/sqrt(3), 0         ],
            [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
            [ 0,         1/sqrt(3), 1/sqrt(2) ]
        ]

    A = A.transpose()

        [
            [ 1/sqrt(2), 1/sqrt(2), 0         ],
            [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
            [ 0,         1/sqrt(2), 1/sqrt(2) ]
        ]

    A = A.dot(deg_mat_inv_sqrt)

        [
            [ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0                     ],
            [ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ],
            [ 0,                     1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ]
        ]

    thus:

        [
            [0.5,        0.40824829, 0.        ],
            [0.40824829, 0.33333333, 0.40824829],
            [0.,         0.40824829, 0.5       ]
        ]

This checks out with the c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| * |N_{r}^{j}|)
formula. Then, we get back to the main calculation:

    y = x * W
    y = A * y

    y = x * W

        [
            [ 1.25, 0.75 ],
            [ 1.5,  1.5  ],
            [ 0.25, 0.75 ]
        ]

    y = A * y

        [
            0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
            0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ]
                + 0.40824829 * [ 0.25, 0.75 ],
            0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
        ]

    that is:

        [
            [1.23737243, 0.98737244],
            [1.11237243, 1.11237243],
            [0.73737244, 0.98737244]
        ]

All checks out nicely, good.
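
The whole walkthrough can be reproduced with a few lines of NumPy. This
is only a sketch for sanity-checking the numbers above; NumPy is not a
dependency of this module:

    import numpy as np

    x = np.array([[0, 1, 0, 1], [1, 1, 1, 0], [0, 0, 0, 1]], dtype=float)
    W = np.array([[0, 1], [1, 0], [0.5, 0.5], [0.25, 0.75]])
    A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)

    A = A + np.eye(len(A))                                 # add self-connections
    deg_mat_inv_sqrt = np.diag(np.power(A.sum(1), -0.5))   # D^{-1/2}
    A = A.dot(deg_mat_inv_sqrt).T.dot(deg_mat_inv_sqrt)    # symmetric normalization

    y = A.dot(x.dot(W))
    # y ~= [[1.2374, 0.9874], [1.1124, 1.1124], [0.7374, 0.9874]]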
""" import torch from .dropout import dropout_sparse, \ dropout from .weights import init_glorot from typing import List, Callable class SparseGraphConv(torch.nn.Module): """Convolution layer for sparse inputs.""" def __init__(self, in_channels: int, out_channels: int, adjacency_matrix: torch.Tensor, **kwargs) -> None: super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.weight = init_glorot(in_channels, out_channels) self.adjacency_matrix = adjacency_matrix def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.sparse.mm(x, self.weight) x = torch.sparse.mm(self.adjacency_matrix, x) return x class SparseDropoutGraphConvActivation(torch.nn.Module): def __init__(self, input_dim: int, output_dim: int, adjacency_matrix: torch.Tensor, keep_prob: float=1., activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, **kwargs) -> None: super().__init__(**kwargs) self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix) self.keep_prob = keep_prob self.activation = activation def forward(self, x: torch.Tensor) -> torch.Tensor: x = dropout_sparse(x, self.keep_prob) x = self.sparse_graph_conv(x) x = self.activation(x) return x class SparseMultiDGCA(torch.nn.Module): def __init__(self, input_dim: List[int], output_dim: int, adjacency_matrices: List[torch.Tensor], keep_prob: float=1., activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, **kwargs) -> None: super().__init__(**kwargs) self.input_dim = input_dim self.output_dim = output_dim self.adjacency_matrices = adjacency_matrices self.keep_prob = keep_prob self.activation = activation self.sparse_dgca = None self.build() def build(self): if len(self.input_dim) != len(self.adjacency_matrices): raise ValueError('input_dim must have the same length as adjacency_matrices') self.sparse_dgca = [] for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) def forward(self, x: List[torch.Tensor]) -> torch.Tensor: out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) for i, f in enumerate(self.sparse_dgca): out += f(x[i]) out = torch.nn.functional.normalize(out, p=2, dim=1) return out class GraphConv(torch.nn.Module): def __init__(self, in_channels: int, out_channels: int, adjacency_matrix: torch.Tensor, **kwargs) -> None: super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.weight = init_glorot(in_channels, out_channels) self.adjacency_matrix = adjacency_matrix def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.mm(x, self.weight) x = torch.mm(self.adjacency_matrix, x) return x class DropoutGraphConvActivation(torch.nn.Module): def __init__(self, input_dim: int, output_dim: int, adjacency_matrix: torch.Tensor, keep_prob: float=1., activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, **kwargs) -> None: super().__init__(**kwargs) self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) self.keep_prob = keep_prob self.activation = activation def forward(self, x: torch.Tensor) -> torch.Tensor: x = dropout(x, keep_prob=self.keep_prob) x = self.graph_conv(x) x = self.activation(x) return x class MultiDGCA(torch.nn.Module): def __init__(self, input_dim: List[int], output_dim: int, adjacency_matrices: List[torch.Tensor], keep_prob: float=1., activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, 
        **kwargs) -> None:
        super().__init__(**kwargs)
        if len(input_dim) != len(adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        self.output_dim = output_dim
        self.dgca = [ DropoutGraphConvActivation(dim, output_dim, adj_mat, keep_prob, activation) \
            for dim, adj_mat in zip(input_dim, adjacency_matrices) ]

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        # Sum the convolved representations over all relations, then L2-normalize.
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
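

# A minimal usage sketch of the dense variant with toy, hypothetical shapes;
# `adj` is assumed to be already normalized as described in the module
# docstring. Because of the relative imports above, this only runs when the
# module is invoked as part of its package (e.g. with `python -m ...`):
if __name__ == '__main__':
    adj = torch.eye(3)                           # toy 3-node adjacency
    layer = MultiDGCA([4], 8, [adj], keep_prob=1.)
    out = layer([torch.rand(3, 4)])              # one input per relation
    print(out.shape)                             # expected: torch.Size([3, 8])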