@@ -0,0 +1,85 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from .dropout import dropout_sparse, \
    dropout_dense
from .weights import init_glorot
from typing import List, Callable


class GraphConv(torch.nn.Module):
    """Convolution layer for sparse AND dense inputs."""
    def __init__(self, in_channels: int, out_channels: int,
        adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project node features first (X @ W), picking the sparse or
        # dense matmul depending on the input's layout...
        x = torch.sparse.mm(x, self.weight) \
            if x.is_sparse \
            else torch.mm(x, self.weight)
        # ...then aggregate over the graph (A @ XW), again honoring
        # the adjacency matrix's layout.
        x = torch.sparse.mm(self.adjacency_matrix, x) \
            if self.adjacency_matrix.is_sparse \
            else torch.mm(self.adjacency_matrix, x)
        return x
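# A minimal usage sketch for GraphConv (illustrative only; the identity
# adjacency matrix and the sizes below are assumptions, not part of the
# library):
#
#     adj = torch.eye(4)                   # identity graph: no mixing
#     conv = GraphConv(8, 2, adj.to_sparse())
#     out = conv(torch.rand((4, 8)))       # dense input, sparse adjacency
#     assert out.shape == (4, 2)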
class DropoutGraphConvActivation(torch.nn.Module):
    """Dropout, followed by GraphConv, followed by an activation."""
    def __init__(self, input_dim: int, output_dim: int,
        adjacency_matrix: torch.Tensor, keep_prob: float=1.,
        activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrix = adjacency_matrix
        self.keep_prob = keep_prob
        self.activation = activation
        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dropout comes first, using the variant that matches the
        # input's layout.
        x = dropout_sparse(x, self.keep_prob) \
            if x.is_sparse \
            else dropout_dense(x, self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x
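# Sketch of the composed layer (illustrative; all shapes and values are
# assumptions):
#
#     layer = DropoutGraphConvActivation(8, 2, torch.eye(4), keep_prob=.5)
#     out = layer(torch.rand((4, 8)))      # dropout -> graph conv -> ReLU
#     assert torch.all(out >= 0)           # ReLU output is non-negative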
class MultiDGCA(torch.nn.Module):
    """Sums DropoutGraphConvActivation outputs over multiple adjacency
    matrices and L2-normalizes the result."""
    def __init__(self, input_dim: List[int], output_dim: int,
        adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
        activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.dgca = None
        self.build()

    def build(self):
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        self.dgca = []
        for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
            self.dgca.append(DropoutGraphConvActivation(input_dim,
                self.output_dim, adj_mat, self.keep_prob, self.activation))

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
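# Hypothetical end-to-end sketch for MultiDGCA (the two input modalities
# and their dimensions are made up for illustration):
#
#     adjs = [torch.eye(4), torch.eye(4)]
#     multi = MultiDGCA([8, 16], 2, adjs, keep_prob=.9)
#     out = multi([torch.rand((4, 8)), torch.rand((4, 16))])
#     assert out.shape == (4, 2)           # summed, then L2-normalized rows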
@@ -0,0 +1,33 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch


def dropout_sparse(x, keep_prob):
    x = x.coalesce()
    i = x.indices()
    v = x.values()
    size = x.size()

    # Bernoulli(keep_prob) mask: floor(keep_prob + U[0, 1)) is 1 with
    # probability keep_prob and 0 otherwise.
    n = keep_prob + torch.rand(len(v))
    n = torch.floor(n).to(torch.bool)
    i = i[:, n]
    v = v[n]
    x = torch.sparse_coo_tensor(i, v, size=size)

    # Inverted dropout: rescale survivors to preserve the expectation.
    return x * (1./keep_prob)
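# Illustrative check (assumed example, not part of the module): with
# keep_prob=1. every entry survives and the rescaling is a no-op.
#
#     s = torch.eye(3).to_sparse()
#     assert torch.all(dropout_sparse(s, 1.).to_dense() == torch.eye(3))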
def dropout_dense(x, keep_prob):
    # Work on a copy; assumes a 2-D dense input.
    x = x.clone().detach()
    i = torch.nonzero(x)

    # Complementary mask: select the entries to drop and zero them out.
    n = keep_prob + torch.rand(len(i))
    n = (1. - torch.floor(n)).to(torch.bool)
    x[i[n, 0], i[n, 1]] = 0.

    return x * (1./keep_prob)
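# The two variants draw the same number of uniforms in the same
# (row-major) order, so with a fixed seed they agree on matching inputs
# (an assumed sketch mirroring the equivalence test further down):
#
#     torch.random.manual_seed(0)
#     a = dropout_dense(torch.eye(3), keep_prob=.5)
#     torch.random.manual_seed(0)
#     b = dropout_sparse(torch.eye(3).to_sparse(), keep_prob=.5)
#     assert torch.all(a == b.to_dense())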
@@ -0,0 +1,19 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
import numpy as np


def init_glorot(in_channels, out_channels, dtype=torch.float32):
    """Create a weight variable with Glorot & Bengio (AISTATS 2010)
    initialization.
    """
    init_range = np.sqrt(6.0 / (in_channels + out_channels))
    # Sample uniformly from [-init_range, init_range).
    initial = -init_range + 2 * init_range * \
        torch.rand((in_channels, out_channels), dtype=dtype)
    initial = initial.requires_grad_(True)
    return initial
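# Quick sanity sketch (assumed example): the values lie within the
# Glorot bound for the given fan-in/fan-out.
#
#     w = init_glorot(256, 64)
#     bound = np.sqrt(6. / (256 + 64))
#     assert w.shape == (256, 64)
#     assert w.abs().max() <= bound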
@@ -0,0 +1,94 @@
from icosagon.convolve import GraphConv, \
    DropoutGraphConvActivation, \
    MultiDGCA
import torch


def _test_graph_conv_01(use_sparse: bool):
    adj_mat = torch.rand((10, 20))
    adj_mat[adj_mat < .5] = 0
    adj_mat = torch.ceil(adj_mat)

    node_reprs = torch.eye(20)

    graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
        if use_sparse else adj_mat)
    graph_conv.weight = torch.eye(20)

    res = graph_conv(node_reprs)
    assert torch.all(res == adj_mat)


def _test_graph_conv_02(use_sparse: bool):
    adj_mat = torch.rand((10, 20))
    adj_mat[adj_mat < .5] = 0
    adj_mat = torch.ceil(adj_mat)

    node_reprs = torch.eye(20)

    graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
        if use_sparse else adj_mat)
    graph_conv.weight = torch.eye(20) * 2

    res = graph_conv(node_reprs)
    assert torch.all(res == adj_mat * 2)


def _test_graph_conv_03(use_sparse: bool):
    adj_mat = torch.tensor([
        [1, 0, 1, 0, 1, 0], # [1, 0, 0]
        [1, 0, 1, 0, 0, 1], # [1, 0, 0]
        [1, 1, 0, 1, 0, 0], # [0, 1, 0]
        [0, 0, 0, 1, 0, 1], # [0, 1, 0]
        [1, 1, 1, 1, 1, 1], # [0, 0, 1]
        [0, 0, 0, 1, 1, 1]  # [0, 0, 1]
    ], dtype=torch.float32)

    # The weight below pools adjacency columns in pairs: (0, 1) -> output
    # column 0, (2, 3) -> column 1, (4, 5) -> column 2.
    expect = torch.tensor([
        [1, 1, 1],
        [1, 1, 1],
        [2, 1, 0],
        [0, 1, 1],
        [2, 2, 2],
        [0, 1, 2]
    ], dtype=torch.float32)

    node_reprs = torch.eye(6)

    graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
        if use_sparse else adj_mat)
    graph_conv.weight = torch.tensor([
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 0, 1],
        [0, 0, 1]
    ], dtype=torch.float32)

    res = graph_conv(node_reprs)
    assert torch.all(res == expect)


def test_graph_conv_dense_01():
    _test_graph_conv_01(use_sparse=False)


def test_graph_conv_dense_02():
    _test_graph_conv_02(use_sparse=False)


def test_graph_conv_dense_03():
    _test_graph_conv_03(use_sparse=False)


def test_graph_conv_sparse_01():
    _test_graph_conv_01(use_sparse=True)


def test_graph_conv_sparse_02():
    _test_graph_conv_02(use_sparse=True)


def test_graph_conv_sparse_03():
    _test_graph_conv_03(use_sparse=True)
@@ -0,0 +1,26 @@
from icosagon.dropout import dropout_sparse, \
    dropout_dense
import torch
import numpy as np


def test_dropout_01():
    for i in range(11):
        torch.random.manual_seed(i)
        a = torch.rand((5, 10))
        a[a < .5] = 0

        # eps keeps keep_prob strictly positive at i == 0, avoiding a
        # division by zero in the rescaling step.
        keep_prob = i/10. + np.finfo(np.float32).eps

        torch.random.manual_seed(i)
        b = dropout_dense(a, keep_prob=keep_prob)

        torch.random.manual_seed(i)
        c = dropout_sparse(a.to_sparse(), keep_prob=keep_prob)

        print('keep_prob:', keep_prob)
        print('a:', a.detach().cpu().numpy())
        print('b:', b.detach().cpu().numpy())
        print('c:', c, c.to_dense().detach().cpu().numpy())

        assert torch.all(b == c.to_dense())