From 058d4d43fb83c6b8ada8d1b73f5077e5fb5d2679 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 24 Jul 2020 11:02:09 +0200 Subject: [PATCH] Add _sparse_diag_cat(). --- src/icosagon/fastconv.py | 33 ++++++++++++++++++++++++++++++++- tests/icosagon/test_fastconv.py | 17 +++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 tests/icosagon/test_fastconv.py diff --git a/src/icosagon/fastconv.py b/src/icosagon/fastconv.py index ba27660..0dd5074 100644 --- a/src/icosagon/fastconv.py +++ b/src/icosagon/fastconv.py @@ -5,6 +5,37 @@ from .data import Data from .trainprep import PreparedData import torch from .weights import init_glorot +from .normalize import _sparse_coo_tensor + + +def _sparse_diag_cat(matrices: List[torch.Tensor]): + if len(matrices) == 0: + raise ValueError('The list of matrices must be non-empty') + + if not all(m.is_sparse for m in matrices): + raise ValueError('All matrices must be sparse') + + if not all(len(m.shape) == 2 for m in matrices): + raise ValueError('All matrices must be 2D') + + indices = [] + values = [] + row_offset = 0 + col_offset = 0 + + for m in matrices: + ind = m._indices().clone() + ind[0] += row_offset + ind[1] += col_offset + indices.append(ind) + values.append(m._values()) + row_offset += m.shape[0] + col_offset += m.shape[1] + + indices = torch.cat(indices, dim=1) + values = torch.cat(values) + + return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) def _cat(matrices: List[torch.Tensor]): @@ -79,7 +110,7 @@ class FastConvLayer(torch.nn.Module): output_dim: List[int], data: Union[Data, PreparedData], keep_prob: float = 1., - rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x + rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, **kwargs): diff --git a/tests/icosagon/test_fastconv.py b/tests/icosagon/test_fastconv.py new file mode 100644 index 0000000..963df9b --- /dev/null +++ b/tests/icosagon/test_fastconv.py @@ -0,0 +1,17 @@ +from icosagon.fastconv import _sparse_diag_cat +import torch + + +def test_sparse_diag_cat_01(): + matrices = [ torch.rand(5, 10).round() for _ in range(7) ] + ground_truth = torch.zeros(35, 70) + ground_truth[0:5, 0:10] = matrices[0] + ground_truth[5:10, 10:20] = matrices[1] + ground_truth[10:15, 20:30] = matrices[2] + ground_truth[15:20, 30:40] = matrices[3] + ground_truth[20:25, 40:50] = matrices[4] + ground_truth[25:30, 50:60] = matrices[5] + ground_truth[30:35, 60:70] = matrices[6] + res = _sparse_diag_cat([ m.to_sparse() for m in matrices ]) + res = res.to_dense() + assert torch.all(res == ground_truth)