#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#


import torch
from ..dropout import dropout
from ..weights import init_glorot
from typing import List, Callable


class DenseGraphConv(torch.nn.Module):
    """Graph convolution with a dense adjacency matrix.

    The forward pass computes ``adjacency_matrix @ (x @ weight)`` using
    ``torch.mm``, so every operand must be a 2-D matrix.

    Args:
        in_channels: Passed as the first argument to ``init_glorot``.
        out_channels: Passed as the second argument to ``init_glorot``.
        adjacency_matrix: Matrix that left-multiplies the projected features.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # presumably an (in_channels, out_channels) Glorot-initialised
        # tensor -- confirm against ..weights.init_glorot
        self.weight = init_glorot(in_channels, out_channels)
        # NOTE(review): stored as a plain attribute, not a registered buffer,
        # so Module.to() will not move it -- callers must place it themselves.
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``adjacency_matrix @ (x @ weight)``."""
        x = torch.mm(x, self.weight)
        x = torch.mm(self.adjacency_matrix, x)
        return x


class DenseDropoutGraphConvActivation(torch.nn.Module):
    """Dropout, then a :class:`DenseGraphConv`, then an activation.

    Args:
        input_dim: ``in_channels`` of the wrapped ``DenseGraphConv``.
        output_dim: ``out_channels`` of the wrapped ``DenseGraphConv``.
        adjacency_matrix: Adjacency matrix of the wrapped ``DenseGraphConv``.
        keep_prob: Forwarded to the project ``dropout`` helper as
            ``keep_prob`` (presumably the probability of keeping a value --
            confirm against ..dropout.dropout).
        activation: Callable applied to the convolution output.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim: int, output_dim: int,
                 adjacency_matrix: torch.Tensor, keep_prob: float=1.,
                 activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply dropout, graph convolution and activation, in that order."""
        x = dropout(x, keep_prob=self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x


class DenseMultiDGCA(torch.nn.Module):
    """Sum of several :class:`DenseDropoutGraphConvActivation` branches.

    One branch is built per ``(input_dim[i], adjacency_matrices[i])`` pair.
    The forward pass feeds ``x[i]`` to branch ``i``, sums the branch outputs
    and L2-normalises each row of the sum.

    Args:
        input_dim: Input dimension of every branch.
        output_dim: Output dimension shared by all branches.
        adjacency_matrices: Adjacency matrix of every branch; must have the
            same length as ``input_dim``.
        keep_prob: Forwarded to every branch.
        activation: Forwarded to every branch.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.

    Raises:
        ValueError: If ``input_dim`` and ``adjacency_matrices`` differ in
            length.
    """

    def __init__(self, input_dim: List[int], output_dim: int,
                 adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
                 activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.dgca = None
        self.build()

    def build(self) -> None:
        """Create one branch per ``(input_dim, adjacency_matrix)`` pair.

        Raises:
            ValueError: If ``input_dim`` and ``adjacency_matrices`` differ in
                length.
        """
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        # A ModuleList (rather than a plain list) registers the branches as
        # child modules, so parameters(), state_dict(), to(), train() and
        # eval() reach them. It still supports len(), indexing and iteration.
        self.dgca = torch.nn.ModuleList(
            DenseDropoutGraphConvActivation(input_dim, self.output_dim,
                                            adj_mat, self.keep_prob,
                                            self.activation)
            for input_dim, adj_mat in zip(self.input_dim,
                                          self.adjacency_matrices))

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        """Return the row-wise L2-normalised sum of all branch outputs.

        Args:
            x: One input tensor per branch; ``x[i]`` is fed to branch ``i``.

        Returns:
            A single tensor: the sum of the branch outputs, normalised with
            ``torch.nn.functional.normalize(p=2, dim=1)``.

        Raises:
            ValueError: If ``x`` is not a ``list``.
        """
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        # Accumulate from the branch outputs instead of a pre-allocated
        # buffer: a branch yields adjacency_matrix.shape[0] rows, which need
        # not equal len(x[0]), and the sum then stays on the inputs' device.
        out = None
        for i, f in enumerate(self.dgca):
            branch_out = f(x[i])
            out = branch_out if out is None else out + branch_out
        if out is None:
            # No branches configured: keep the all-zero result of the
            # original implementation, created on the input's device.
            out = torch.zeros(len(x[0]), self.output_dim,
                              dtype=x[0].dtype, device=x[0].device)
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out