- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
- """
- This module implements the basic convolutional blocks of Decagon.
- Just as a quick reminder, the basic convolution formula here is:
- y = A * (x * W)
- where:
- W is a weight matrix
- A is an adjacency matrix
- x is a matrix of latent representations of a particular type of neighbors.
- As we have x here twice, a trick is obviously necessary for this to work.
- A must be previously normalized with:
- c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
- or
- c_{r}^{i} = 1/|N_{r}^{i}|
- Let's work through this step by step to convince ourselves that the
- formula is correct.
- x = [
- [0, 1, 0, 1],
- [1, 1, 1, 0],
- [0, 0, 0, 1]
- ]
- W = [
- [0, 1],
- [1, 0],
- [0.5, 0.5],
- [0.25, 0.75]
- ]
- A = [
- [0, 1, 0],
- [1, 0, 1],
- [0, 1, 0]
- ]
- so the graph looks like this:
- (0) -- (1) -- (2)
- and therefore the representations in the next layer should be:
- h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
- h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W +
- c_{r}^{1} * h_{1}^{k}
- h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}
- In the actual Decagon code we can see that the latter part, which directly
- propagates the old representation, is gone. I will try to do the same for now.
- So we have to only take care of:
- h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
- h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W
- h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W
- If A is square the Decagon's EdgeMinibatchIterator preprocesses it as follows:
- A = A + eye(len(A))
- rowsum = A.sum(1)
- deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
- A = dot(A, deg_mat_inv_sqrt)
- A = A.transpose()
- A = A.dot(deg_mat_inv_sqrt)
- Let's see what this gives in our case:
- A = A + eye(len(A))
- [
- [1, 1, 0],
- [1, 1, 1],
- [0, 1, 1]
- ]
- rowsum = A.sum(1)
- [2, 3, 2]
- deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
- [
- [1./sqrt(2), 0, 0],
- [0, 1./sqrt(3), 0],
- [0, 0, 1./sqrt(2)]
- ]
- A = dot(A, deg_mat_inv_sqrt)
- [
- [ 1/sqrt(2), 1/sqrt(3), 0 ],
- [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
- [ 0, 1/sqrt(3), 1/sqrt(2) ]
- ]
- A = A.transpose()
- [
- [ 1/sqrt(2), 1/sqrt(2), 0 ],
- [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
- [ 0, 1/sqrt(2), 1/sqrt(2) ]
- ]
- A = A.dot(deg_mat_inv_sqrt)
- [
- [ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ],
- [ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ],
- [ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ],
- ]
- thus:
- [
- [0.5 , 0.40824829, 0. ],
- [0.40824829, 0.33333333, 0.40824829],
- [0. , 0.40824829, 0.5 ]
- ]
- This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula.
- Then, we get back to the main calculation:
- y = x * W
- y = A * y
- y = x * W
- [
- [ 1.25, 0.75 ],
- [ 1.5 , 1.5 ],
- [ 0.25, 0.75 ]
- ]
- y = A * y
- [
- 0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
- 0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ],
- 0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
- ]
- that is:
- [
- [1.23737243, 0.98737244],
- [1.11237243, 1.11237243],
- [0.73737244, 0.98737244]
- ].
- All checks out nicely, good.
- """
- import torch
- from .dropout import dropout_sparse, \
- dropout
- from .weights import init_glorot
- from typing import List, Callable
class SparseGraphConv(torch.nn.Module):
    """Graph convolution ``y = A * (x * W)`` computed with sparse products.

    ``W`` is a trainable weight matrix obtained from ``init_glorot`` and ``A``
    is the adjacency matrix handed to the constructor. The adjacency matrix is
    used exactly as given, so any normalization has to be applied by the
    caller beforehand.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 adjacency_matrix: torch.Tensor, **kwargs) -> None:
        """Create the layer.

        Args:
            in_channels: First argument passed to ``init_glorot``
                (presumably the number of input features -- confirm against
                ``.weights``).
            out_channels: Second argument passed to ``init_glorot``
                (presumably the number of output features).
            adjacency_matrix: Matrix ``A`` that left-multiplies ``x * W`` in
                :meth:`forward`.
            **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Wrapping in Parameter guarantees the weight is registered with the
        # module (parameters(), state_dict(), .to()) whatever tensor type
        # init_glorot hands back.
        self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``A * (x * W)``.

        Both products go through ``torch.sparse.mm``; ``x`` is expected to be
        a sparse matrix (assumption based on the class purpose -- the code
        itself does not check).
        """
        projected = torch.sparse.mm(x, self.weight)
        return torch.sparse.mm(self.adjacency_matrix, projected)
class SparseDropoutGraphConvActivation(torch.nn.Module):
    """Pipeline of sparse dropout, sparse graph convolution and activation."""

    def __init__(self, input_dim: int, output_dim: int,
                 adjacency_matrix: torch.Tensor, keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        """Build the inner ``SparseGraphConv`` and remember the settings.

        Args:
            input_dim: Passed to ``SparseGraphConv`` as ``in_channels``.
            output_dim: Passed to ``SparseGraphConv`` as ``out_channels``.
            adjacency_matrix: Passed through to ``SparseGraphConv``.
            keep_prob: Second argument given to ``dropout_sparse`` on every
                forward pass.
            activation: Callable applied to the convolution output.
            **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.sparse_graph_conv = SparseGraphConv(
            input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply dropout, then the convolution, then the activation to ``x``."""
        dropped = dropout_sparse(x, self.keep_prob)
        convolved = self.sparse_graph_conv(dropped)
        return self.activation(convolved)
class SparseMultiDGCA(torch.nn.Module):
    """Sum of several sparse dropout/graph-convolution/activation branches.

    One ``SparseDropoutGraphConvActivation`` branch is built per
    ``(input_dim, adjacency_matrix)`` pair. :meth:`forward` feeds the i-th
    input to the i-th branch, adds up the branch outputs and L2-normalizes
    every row of the sum.
    """

    def __init__(self, input_dim: List[int], output_dim: int,
                 adjacency_matrices: List[torch.Tensor], keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        """Store the configuration and build the branches.

        Args:
            input_dim: Input dimension of each branch, one entry per
                adjacency matrix.
            output_dim: Output dimension shared by all branches.
            adjacency_matrices: One adjacency matrix per branch.
            keep_prob: Dropout keep probability handed to every branch.
            activation: Activation handed to every branch.
            **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.

        Raises:
            ValueError: If ``input_dim`` and ``adjacency_matrices`` differ in
                length.
        """
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.sparse_dgca = None
        self.build()

    def build(self) -> None:
        """(Re)create one branch per ``(input_dim, adjacency_matrix)`` pair.

        Raises:
            ValueError: If ``input_dim`` and ``adjacency_matrices`` differ in
                length.
        """
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        # A ModuleList (rather than a plain list) registers the branches as
        # sub-modules, so their parameters are visible to parameters(),
        # state_dict() and .to(). Indexing, len() and iteration still work.
        self.sparse_dgca = torch.nn.ModuleList([
            SparseDropoutGraphConvActivation(
                input_dim, self.output_dim, adj_mat,
                self.keep_prob, self.activation)
            for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices)
        ])

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        """Return the row-wise L2-normalized sum of all branch outputs.

        Args:
            x: One input tensor per branch, in branch order. Entries beyond
                the number of branches are ignored.

        Raises:
            ValueError: If ``x`` is not a list or holds fewer tensors than
                there are branches.
        """
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        if len(x) < len(self.sparse_dgca):
            raise ValueError('x must contain one tensor per adjacency matrix')
        # Accumulate the branch outputs directly instead of adding them into
        # a pre-allocated zeros(len(x[0]), output_dim): the result then has
        # the row count, dtype and device of the branch outputs, which also
        # works for non-square adjacency matrices and non-CPU tensors.
        out = None
        for branch, features in zip(self.sparse_dgca, x):
            contribution = branch(features)
            out = contribution if out is None else out + contribution
        if out is None:
            # No branches configured: keep the historical all-zero result.
            out = torch.zeros(len(x[0]), self.output_dim,
                              dtype=x[0].dtype, device=x[0].device)
        return torch.nn.functional.normalize(out, p=2, dim=1)
class GraphConv(torch.nn.Module):
    """Graph convolution ``y = A * (x * W)`` computed with dense products.

    ``W`` is a trainable weight matrix obtained from ``init_glorot`` and ``A``
    is the adjacency matrix handed to the constructor. The adjacency matrix is
    used exactly as given, so any normalization has to be applied by the
    caller beforehand.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 adjacency_matrix: torch.Tensor, **kwargs) -> None:
        """Create the layer.

        Args:
            in_channels: First argument passed to ``init_glorot``
                (presumably the number of input features -- confirm against
                ``.weights``).
            out_channels: Second argument passed to ``init_glorot``
                (presumably the number of output features).
            adjacency_matrix: Matrix ``A`` that left-multiplies ``x * W`` in
                :meth:`forward`.
            **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Wrapping in Parameter guarantees the weight is registered with the
        # module (parameters(), state_dict(), .to()) whatever tensor type
        # init_glorot hands back.
        self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``A * (x * W)`` using ``torch.mm`` for both products."""
        projected = torch.mm(x, self.weight)
        return torch.mm(self.adjacency_matrix, projected)
class DropoutGraphConvActivation(torch.nn.Module):
    """Pipeline of dropout, dense graph convolution and activation."""

    def __init__(self, input_dim: int, output_dim: int,
                 adjacency_matrix: torch.Tensor, keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        """Build the inner ``GraphConv`` and remember the settings.

        Args:
            input_dim: Passed to ``GraphConv`` as ``in_channels``.
            output_dim: Passed to ``GraphConv`` as ``out_channels``.
            adjacency_matrix: Passed through to ``GraphConv``.
            keep_prob: Given to ``dropout`` as ``keep_prob`` on every forward
                pass.
            activation: Callable applied to the convolution output.
            **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply dropout, then the convolution, then the activation to ``x``."""
        dropped = dropout(x, keep_prob=self.keep_prob)
        convolved = self.graph_conv(dropped)
        return self.activation(convolved)
class MultiDGCA(torch.nn.Module):
    """Sum of several dense dropout/graph-convolution/activation branches.

    One ``DropoutGraphConvActivation`` branch is built per
    ``(input_dim, adjacency_matrix)`` pair. :meth:`forward` feeds the i-th
    input to the i-th branch, adds up the branch outputs and L2-normalizes
    every row of the sum.
    """

    def __init__(self, input_dim: List[int], output_dim: int,
                 adjacency_matrices: List[torch.Tensor], keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        """Store the configuration and build the branches.

        Args:
            input_dim: Input dimension of each branch, one entry per
                adjacency matrix.
            output_dim: Output dimension shared by all branches.
            adjacency_matrices: One adjacency matrix per branch.
            keep_prob: Dropout keep probability handed to every branch.
            activation: Activation handed to every branch.
            **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.

        Raises:
            ValueError: If ``input_dim`` and ``adjacency_matrices`` differ in
                length.
        """
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.dgca = None
        self.build()

    def build(self) -> None:
        """(Re)create one branch per ``(input_dim, adjacency_matrix)`` pair.

        Raises:
            ValueError: If ``input_dim`` and ``adjacency_matrices`` differ in
                length.
        """
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        # A ModuleList (rather than a plain list) registers the branches as
        # sub-modules, so their parameters are visible to parameters(),
        # state_dict() and .to(). Indexing, len() and iteration still work.
        self.dgca = torch.nn.ModuleList([
            DropoutGraphConvActivation(
                input_dim, self.output_dim, adj_mat,
                self.keep_prob, self.activation)
            for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices)
        ])

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        """Return the row-wise L2-normalized sum of all branch outputs.

        Args:
            x: One input tensor per branch, in branch order. Entries beyond
                the number of branches are ignored.

        Raises:
            ValueError: If ``x`` is not a list or holds fewer tensors than
                there are branches.
        """
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        if len(x) < len(self.dgca):
            raise ValueError('x must contain one tensor per adjacency matrix')
        # Accumulate the branch outputs directly instead of adding them into
        # a pre-allocated zeros(len(x[0]), output_dim): the result then has
        # the row count, dtype and device of the branch outputs, which also
        # works for non-square adjacency matrices and non-CPU tensors.
        out = None
        for branch, features in zip(self.dgca, x):
            contribution = branch(features)
            out = contribution if out is None else out + contribution
        if out is None:
            # No branches configured: keep the historical all-zero result.
            out = torch.zeros(len(x[0]), self.output_dim,
                              dtype=x[0].dtype, device=x[0].device)
        return torch.nn.functional.normalize(out, p=2, dim=1)