From 271fba00041f6d7d8ef82605a65afbf14a21e1ce Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 24 Jul 2020 12:27:21 +0200 Subject: [PATCH] Add test_cat_01() and test_cat_02(). --- src/icosagon/fastconv.py | 4 ++-- tests/icosagon/test_fastconv.py | 26 +++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/icosagon/fastconv.py b/src/icosagon/fastconv.py index 0dd5074..74e0319 100644 --- a/src/icosagon/fastconv.py +++ b/src/icosagon/fastconv.py @@ -46,7 +46,7 @@ def _cat(matrices: List[torch.Tensor]): if n != 0 and n != len(matrices): raise ValueError('All matrices must have the same layout (dense or sparse)') - if not all(a.shape[1:] == matrices[0].shape[1:]): + if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): raise ValueError('All matrices must have the same dimensions apart from dimension 0') if not matrices[0].is_sparse: @@ -69,7 +69,7 @@ def _cat(matrices: List[torch.Tensor]): indices = torch.cat(indices).transpose(0, 1) values = torch.cat(values) - res = _sparse_coo_tensor(indices, values) + res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) return res diff --git a/tests/icosagon/test_fastconv.py b/tests/icosagon/test_fastconv.py index 8aaf00d..799da5d 100644 --- a/tests/icosagon/test_fastconv.py +++ b/tests/icosagon/test_fastconv.py @@ -1,4 +1,6 @@ -from icosagon.fastconv import _sparse_diag_cat +from icosagon.fastconv import _sparse_diag_cat, \ + _cat +from icosagon.data import _equal import torch @@ -34,3 +36,25 @@ def test_sparse_diag_cat_02(): ground_truth[30:35, :] = torch.mm(x[6], b[60:70]) assert torch.all(res == ground_truth) + + +def test_cat_01(): + matrices = [ torch.rand(5, 10) for _ in range(7) ] + res = _cat(matrices) + assert res.shape == (35, 10) + assert not res.is_sparse + ground_truth = torch.zeros(35, 10) + for i in range(7): + ground_truth[i*5:(i+1)*5, :] = matrices[i] + assert torch.all(res == ground_truth) + + +def test_cat_02(): + matrices = [ torch.rand(5, 10) for _ in range(7) ] + ground_truth = torch.zeros(35, 10) + for i in range(7): + ground_truth[i*5:(i+1)*5, :] = matrices[i] + res = _cat([ m.to_sparse() for m in matrices ]) + assert res.shape == (35, 10) + assert res.is_sparse + assert torch.all(res.to_dense() == ground_truth)