#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#


import torch
from .dropout import dropout
from .weights import init_glorot
from typing import Callable


class GraphConv(torch.nn.Module):
    """Plain graph convolution layer: computes A @ (X @ W), i.e. a learned
    linear projection of the node features followed by propagation along
    the edges of the fixed adjacency matrix A."""

    def __init__(self, in_channels: int, out_channels: int,
        adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project the node features, using sparse matmul if x is sparse.
        x = torch.sparse.mm(x, self.weight) \
            if x.is_sparse \
            else torch.mm(x, self.weight)
        # Propagate along the edges, using sparse matmul if A is sparse.
        x = torch.sparse.mm(self.adjacency_matrix, x) \
            if self.adjacency_matrix.is_sparse \
            else torch.mm(self.adjacency_matrix, x)
        return x


class DropoutGraphConvActivation(torch.nn.Module):
    """Dropout -> GraphConv -> activation: the basic convolutional
    building block of the encoder."""

    def __init__(self, input_dim: int, output_dim: int,
        adjacency_matrix: torch.Tensor, keep_prob: float=1.,
        activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrix = adjacency_matrix
        self.keep_prob = keep_prob
        self.activation = activation
        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout(x, self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x

    def clone(self, adjacency_matrix: torch.Tensor) -> 'DropoutGraphConvActivation':
        # Copy of this layer for a different adjacency matrix; the weight
        # Parameter is shared with the original, not duplicated.
        res = DropoutGraphConvActivation(self.input_dim,
            self.output_dim, adjacency_matrix, self.keep_prob,
            self.activation)
        res.graph_conv.weight = self.graph_conv.weight
        return res
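

# Minimal smoke test on a toy graph (an illustrative sketch, not a formal
# test suite). It assumes the sibling `dropout` and `init_glorot` helpers
# behave as their call sites above suggest (shape-preserving dropout, an
# (in, out)-shaped init) and that the file is executed as part of its
# package (python -m ...) so that the relative imports resolve.
if __name__ == '__main__':
    n_nodes = 5
    # Identity adjacency as a stand-in for a real (sparse) graph.
    adj = torch.eye(n_nodes).to_sparse()
    x = torch.rand((n_nodes, 8))

    conv = GraphConv(8, 4, adj)
    print(conv(x).shape)                # torch.Size([5, 4])

    layer = DropoutGraphConvActivation(8, 4, adj, keep_prob=0.9)
    print(layer(x).shape)               # torch.Size([5, 4])

    # clone() shares the weight Parameter between the two layers.
    other = layer.clone(adj)
    print(other.graph_conv.weight is layer.graph_conv.weight)  # True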