@@ -46,7 +46,7 @@ def _cat(matrices: List[torch.Tensor]): | |||||
if n != 0 and n != len(matrices): | if n != 0 and n != len(matrices): | ||||
raise ValueError('All matrices must have the same layout (dense or sparse)') | raise ValueError('All matrices must have the same layout (dense or sparse)') | ||||
if not all(a.shape[1:] == matrices[0].shape[1:]): | |||||
if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): | |||||
raise ValueError('All matrices must have the same dimensions apart from dimension 0') | raise ValueError('All matrices must have the same dimensions apart from dimension 0') | ||||
if not matrices[0].is_sparse: | if not matrices[0].is_sparse: | ||||
@@ -69,7 +69,7 @@ def _cat(matrices: List[torch.Tensor]): | |||||
indices = torch.cat(indices).transpose(0, 1) | indices = torch.cat(indices).transpose(0, 1) | ||||
values = torch.cat(values) | values = torch.cat(values) | ||||
res = _sparse_coo_tensor(indices, values) | |||||
res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) | |||||
return res | return res | ||||
@@ -1,4 +1,6 @@ | |||||
from icosagon.fastconv import _sparse_diag_cat | |||||
from icosagon.fastconv import _sparse_diag_cat, \ | |||||
_cat | |||||
from icosagon.data import _equal | |||||
import torch | import torch | ||||
@@ -34,3 +36,25 @@ def test_sparse_diag_cat_02(): | |||||
ground_truth[30:35, :] = torch.mm(x[6], b[60:70]) | ground_truth[30:35, :] = torch.mm(x[6], b[60:70]) | ||||
assert torch.all(res == ground_truth) | assert torch.all(res == ground_truth) | ||||
def test_cat_01(): | |||||
matrices = [ torch.rand(5, 10) for _ in range(7) ] | |||||
res = _cat(matrices) | |||||
assert res.shape == (35, 10) | |||||
assert not res.is_sparse | |||||
ground_truth = torch.zeros(35, 10) | |||||
for i in range(7): | |||||
ground_truth[i*5:(i+1)*5, :] = matrices[i] | |||||
assert torch.all(res == ground_truth) | |||||
def test_cat_02(): | |||||
matrices = [ torch.rand(5, 10) for _ in range(7) ] | |||||
ground_truth = torch.zeros(35, 10) | |||||
for i in range(7): | |||||
ground_truth[i*5:(i+1)*5, :] = matrices[i] | |||||
res = _cat([ m.to_sparse() for m in matrices ]) | |||||
assert res.shape == (35, 10) | |||||
assert res.is_sparse | |||||
assert torch.all(res.to_dense() == ground_truth) |