IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Refactor convolve into 3 separate modules.

master
Stanislaw Adaszewski 3 years ago
parent
commit
c74555fef5
5 changed files with 404 additions and 346 deletions
  1. +0
    -346
      src/decagon_pytorch/convolve.py
  2. +168
    -0
      src/decagon_pytorch/convolve/__init__.py
  3. +73
    -0
      src/decagon_pytorch/convolve/dense.py
  4. +78
    -0
      src/decagon_pytorch/convolve/sparse.py
  5. +85
    -0
      src/decagon_pytorch/convolve/universal.py

+ 0
- 346
src/decagon_pytorch/convolve.py View File

@@ -1,346 +0,0 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

"""
This module implements the basic convolutional blocks of Decagon.
Just as a quick reminder, the basic convolution formula here is:

y = A * (x * W)

where:

W is a weight matrix
A is an adjacency matrix
x is a matrix of latent representations of a particular type of neighbors.

As we have x here twice, a trick is obviously necessary for this to work.
A must be previously normalized with:

c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)

or

c_{r}^{i} = 1/|N_{r}^{i}|

Let's work through this step by step to convince ourselves that the
formula is correct.

x = [
[0, 1, 0, 1],
[1, 1, 1, 0],
[0, 0, 0, 1]
]

W = [
[0, 1],
[1, 0],
[0.5, 0.5],
[0.25, 0.75]
]

A = [
[0, 1, 0],
[1, 0, 1],
[0, 1, 0]
]

so the graph looks like this:

(0) -- (1) -- (2)

and therefore the representations in the next layer should be:

h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} +
c_{r}^{1} * h_{1}^{k}
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}

In actual Decagon code we can see that that latter part propagating directly
the old representation is gone. I will try to do the same for now.

So we have to only take care of:

h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k}
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W

If A is square the Decagon's EdgeMinibatchIterator preprocesses it as follows:

A = A + eye(len(A))
rowsum = A.sum(1)
deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
A = dot(A, deg_mat_inv_sqrt)
A = A.transpose()
A = A.dot(deg_mat_inv_sqrt)

Let's see what gives in our case:

A = A + eye(len(A))

[
[1, 1, 0],
[1, 1, 1],
[0, 1, 1]
]

rowsum = A.sum(1)

[2, 3, 2]

deg_mat_inv_sqrt = diags(power(rowsum, -0.5))

[
[1./sqrt(2), 0, 0],
[0, 1./sqrt(3), 0],
[0, 0, 1./sqrt(2)]
]

A = dot(A, deg_mat_inv_sqrt)

[
[ 1/sqrt(2), 1/sqrt(3), 0 ],
[ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
[ 0, 1/sqrt(3), 1/sqrt(2) ]
]

A = A.transpose()

[
[ 1/sqrt(2), 1/sqrt(2), 0 ],
[ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
[ 0, 1/sqrt(2), 1/sqrt(2) ]
]

A = A.dot(deg_mat_inv_sqrt)

[
[ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ],
[ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ],
[ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ],
]

thus:

[
[0.5 , 0.40824829, 0. ],
[0.40824829, 0.33333333, 0.40824829],
[0. , 0.40824829, 0.5 ]
]

This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula.

Then, we get back to the main calculation:

y = x * W
y = A * y

y = x * W

[
[ 1.25, 0.75 ],
[ 1.5 , 1.5 ],
[ 0.25, 0.75 ]
]

y = A * y

[
0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ],
0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
]

that is:

[
[1.23737243, 0.98737244],
[1.11237243, 1.11237243],
[0.73737244, 0.98737244]
].

All checks out nicely, good.

"""


import torch
from .dropout import dropout_sparse, \
dropout
from .weights import init_glorot
from typing import List, Callable


class SparseGraphConv(torch.nn.Module):
"""Convolution layer for sparse inputs."""
def __init__(self, in_channels: int, out_channels: int,
adjacency_matrix: torch.Tensor, **kwargs) -> None:
super().__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = init_glorot(in_channels, out_channels)
self.adjacency_matrix = adjacency_matrix


def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.sparse.mm(x, self.weight)
x = torch.sparse.mm(self.adjacency_matrix, x)
return x


class SparseDropoutGraphConvActivation(torch.nn.Module):
def __init__(self, input_dim: int, output_dim: int,
adjacency_matrix: torch.Tensor, keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrix = adjacency_matrix
self.keep_prob = keep_prob
self.activation = activation
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = dropout_sparse(x, self.keep_prob)
x = self.sparse_graph_conv(x)
x = self.activation(x)
return x


class SparseMultiDGCA(torch.nn.Module):
def __init__(self, input_dim: List[int], output_dim: int,
adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrices = adjacency_matrices
self.keep_prob = keep_prob
self.activation = activation
self.sparse_dgca = None
self.build()

def build(self):
if len(self.input_dim) != len(self.adjacency_matrices):
raise ValueError('input_dim must have the same length as adjacency_matrices')
self.sparse_dgca = []
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))

def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
if not isinstance(x, list):
raise ValueError('x must be a list of tensors')
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
for i, f in enumerate(self.sparse_dgca):
out += f(x[i])
out = torch.nn.functional.normalize(out, p=2, dim=1)
return out


class DenseGraphConv(torch.nn.Module):
def __init__(self, in_channels: int, out_channels: int,
adjacency_matrix: torch.Tensor, **kwargs) -> None:
super().__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = init_glorot(in_channels, out_channels)
self.adjacency_matrix = adjacency_matrix

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.mm(x, self.weight)
x = torch.mm(self.adjacency_matrix, x)
return x


class DenseDropoutGraphConvActivation(torch.nn.Module):
def __init__(self, input_dim: int, output_dim: int,
adjacency_matrix: torch.Tensor, keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix)
self.keep_prob = keep_prob
self.activation = activation

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = dropout(x, keep_prob=self.keep_prob)
x = self.graph_conv(x)
x = self.activation(x)
return x


class DenseMultiDGCA(torch.nn.Module):
def __init__(self, input_dim: List[int], output_dim: int,
adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrices = adjacency_matrices
self.keep_prob = keep_prob
self.activation = activation
self.dgca = None
self.build()

def build(self):
if len(self.input_dim) != len(self.adjacency_matrices):
raise ValueError('input_dim must have the same length as adjacency_matrices')
self.dgca = []
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))

def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(x, list):
raise ValueError('x must be a list of tensors')
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
for i, f in enumerate(self.dgca):
out += f(x[i])
out = torch.nn.functional.normalize(out, p=2, dim=1)
return out


class GraphConv(torch.nn.Module):
"""Convolution layer for sparse AND dense inputs."""
def __init__(self, in_channels: int, out_channels: int,
adjacency_matrix: torch.Tensor, **kwargs) -> None:
super().__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = init_glorot(in_channels, out_channels)
self.adjacency_matrix = adjacency_matrix


def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.sparse.mm(x, self.weight) \
if x.is_sparse \
else torch.mm(x, self.weight)
x = torch.sparse.mm(self.adjacency_matrix, x) \
if self.adjacency_matrix.is_sparse \
else torch.mm(self.adjacency_matrix, x)
return x


class DropoutGraphConvActivation(torch.nn.Module):
def __init__(self, input_dim: int, output_dim: int,
adjacency_matrix: torch.Tensor, keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrix = adjacency_matrix
self.keep_prob = keep_prob
self.activation = activation
self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = dropout_sparse(x, self.keep_prob) \
if x.is_sparse \
else dropout(x, self.keep_prob)
x = self.graph_conv(x)
x = self.activation(x)
return x

+ 168
- 0
src/decagon_pytorch/convolve/__init__.py View File

@@ -0,0 +1,168 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
"""
This module implements the basic convolutional blocks of Decagon.
Just as a quick reminder, the basic convolution formula here is:
y = A * (x * W)
where:
W is a weight matrix
A is an adjacency matrix
x is a matrix of latent representations of a particular type of neighbors.
As we have x here twice, a trick is obviously necessary for this to work.
A must be previously normalized with:
c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
or
c_{r}^{i} = 1/|N_{r}^{i}|
Let's work through this step by step to convince ourselves that the
formula is correct.
x = [
[0, 1, 0, 1],
[1, 1, 1, 0],
[0, 0, 0, 1]
]
W = [
[0, 1],
[1, 0],
[0.5, 0.5],
[0.25, 0.75]
]
A = [
[0, 1, 0],
[1, 0, 1],
[0, 1, 0]
]
so the graph looks like this:
(0) -- (1) -- (2)
and therefore the representations in the next layer should be:
h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} +
c_{r}^{1} * h_{1}^{k}
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}
In actual Decagon code we can see that that latter part propagating directly
the old representation is gone. I will try to do the same for now.
So we have to only take care of:
h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k}
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W
If A is square the Decagon's EdgeMinibatchIterator preprocesses it as follows:
A = A + eye(len(A))
rowsum = A.sum(1)
deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
A = dot(A, deg_mat_inv_sqrt)
A = A.transpose()
A = A.dot(deg_mat_inv_sqrt)
Let's see what gives in our case:
A = A + eye(len(A))
[
[1, 1, 0],
[1, 1, 1],
[0, 1, 1]
]
rowsum = A.sum(1)
[2, 3, 2]
deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
[
[1./sqrt(2), 0, 0],
[0, 1./sqrt(3), 0],
[0, 0, 1./sqrt(2)]
]
A = dot(A, deg_mat_inv_sqrt)
[
[ 1/sqrt(2), 1/sqrt(3), 0 ],
[ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
[ 0, 1/sqrt(3), 1/sqrt(2) ]
]
A = A.transpose()
[
[ 1/sqrt(2), 1/sqrt(2), 0 ],
[ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
[ 0, 1/sqrt(2), 1/sqrt(2) ]
]
A = A.dot(deg_mat_inv_sqrt)
[
[ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ],
[ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ],
[ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ],
]
thus:
[
[0.5 , 0.40824829, 0. ],
[0.40824829, 0.33333333, 0.40824829],
[0. , 0.40824829, 0.5 ]
]
This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula.
Then, we get back to the main calculation:
y = x * W
y = A * y
y = x * W
[
[ 1.25, 0.75 ],
[ 1.5 , 1.5 ],
[ 0.25, 0.75 ]
]
y = A * y
[
0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ],
0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
]
that is:
[
[1.23737243, 0.98737244],
[1.11237243, 1.11237243],
[0.73737244, 0.98737244]
].
All checks out nicely, good.
"""
from .dense import *
from .sparse import *
from .universal import *

+ 73
- 0
src/decagon_pytorch/convolve/dense.py View File

@@ -0,0 +1,73 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
import torch
from .dropout import dropout
from .weights import init_glorot
from typing import List, Callable
class DenseGraphConv(torch.nn.Module):
def __init__(self, in_channels: int, out_channels: int,
adjacency_matrix: torch.Tensor, **kwargs) -> None:
super().__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = init_glorot(in_channels, out_channels)
self.adjacency_matrix = adjacency_matrix
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.mm(x, self.weight)
x = torch.mm(self.adjacency_matrix, x)
return x
class DenseDropoutGraphConvActivation(torch.nn.Module):
def __init__(self, input_dim: int, output_dim: int,
adjacency_matrix: torch.Tensor, keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix)
self.keep_prob = keep_prob
self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = dropout(x, keep_prob=self.keep_prob)
x = self.graph_conv(x)
x = self.activation(x)
return x
class DenseMultiDGCA(torch.nn.Module):
def __init__(self, input_dim: List[int], output_dim: int,
adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrices = adjacency_matrices
self.keep_prob = keep_prob
self.activation = activation
self.dgca = None
self.build()
def build(self):
if len(self.input_dim) != len(self.adjacency_matrices):
raise ValueError('input_dim must have the same length as adjacency_matrices')
self.dgca = []
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(x, list):
raise ValueError('x must be a list of tensors')
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
for i, f in enumerate(self.dgca):
out += f(x[i])
out = torch.nn.functional.normalize(out, p=2, dim=1)
return out

+ 78
- 0
src/decagon_pytorch/convolve/sparse.py View File

@@ -0,0 +1,78 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
import torch
from .dropout import dropout_sparse
from .weights import init_glorot
from typing import List, Callable
class SparseGraphConv(torch.nn.Module):
"""Convolution layer for sparse inputs."""
def __init__(self, in_channels: int, out_channels: int,
adjacency_matrix: torch.Tensor, **kwargs) -> None:
super().__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = init_glorot(in_channels, out_channels)
self.adjacency_matrix = adjacency_matrix
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.sparse.mm(x, self.weight)
x = torch.sparse.mm(self.adjacency_matrix, x)
return x
class SparseDropoutGraphConvActivation(torch.nn.Module):
def __init__(self, input_dim: int, output_dim: int,
adjacency_matrix: torch.Tensor, keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrix = adjacency_matrix
self.keep_prob = keep_prob
self.activation = activation
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = dropout_sparse(x, self.keep_prob)
x = self.sparse_graph_conv(x)
x = self.activation(x)
return x
class SparseMultiDGCA(torch.nn.Module):
def __init__(self, input_dim: List[int], output_dim: int,
adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrices = adjacency_matrices
self.keep_prob = keep_prob
self.activation = activation
self.sparse_dgca = None
self.build()
def build(self):
if len(self.input_dim) != len(self.adjacency_matrices):
raise ValueError('input_dim must have the same length as adjacency_matrices')
self.sparse_dgca = []
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
if not isinstance(x, list):
raise ValueError('x must be a list of tensors')
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
for i, f in enumerate(self.sparse_dgca):
out += f(x[i])
out = torch.nn.functional.normalize(out, p=2, dim=1)
return out

+ 85
- 0
src/decagon_pytorch/convolve/universal.py View File

@@ -0,0 +1,85 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
import torch
from .dropout import dropout_sparse, \
dropout
from .weights import init_glorot
from typing import List, Callable
class GraphConv(torch.nn.Module):
"""Convolution layer for sparse AND dense inputs."""
def __init__(self, in_channels: int, out_channels: int,
adjacency_matrix: torch.Tensor, **kwargs) -> None:
super().__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = init_glorot(in_channels, out_channels)
self.adjacency_matrix = adjacency_matrix
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.sparse.mm(x, self.weight) \
if x.is_sparse \
else torch.mm(x, self.weight)
x = torch.sparse.mm(self.adjacency_matrix, x) \
if self.adjacency_matrix.is_sparse \
else torch.mm(self.adjacency_matrix, x)
return x
class DropoutGraphConvActivation(torch.nn.Module):
def __init__(self, input_dim: int, output_dim: int,
adjacency_matrix: torch.Tensor, keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrix = adjacency_matrix
self.keep_prob = keep_prob
self.activation = activation
self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = dropout_sparse(x, self.keep_prob) \
if x.is_sparse \
else dropout(x, self.keep_prob)
x = self.graph_conv(x)
x = self.activation(x)
return x
class MultiDGCA(torch.nn.Module):
def __init__(self, input_dim: List[int], output_dim: int,
adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrices = adjacency_matrices
self.keep_prob = keep_prob
self.activation = activation
self.dgca = None
self.build()
def build(self):
if len(self.input_dim) != len(self.adjacency_matrices):
raise ValueError('input_dim must have the same length as adjacency_matrices')
self.dgca = []
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(x, list):
raise ValueError('x must be a list of tensors')
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
for i, f in enumerate(self.dgca):
out += f(x[i])
out = torch.nn.functional.normalize(out, p=2, dim=1)
return out

Loading…
Cancel
Save