"""Tests for ``icosagon.fastconv._sparse_diag_cat``.

The tests below expect ``_sparse_diag_cat`` to place a list of sparse matrices
along the diagonal of one larger sparse matrix (block-diagonal concatenation).
"""

from icosagon.fastconv import _sparse_diag_cat
import torch


# Number of diagonal blocks each test concatenates, and the shape of each block.
_N_BLOCKS = 7
_BLOCK_ROWS = 5
_BLOCK_COLS = 10


def _random_blocks():
    """Return ``_N_BLOCKS`` dense random 0/1 matrices of shape (5, 10)."""
    # round() maps U[0, 1) samples to {0., 1.}, giving a mix of zero and
    # non-zero entries so the sparse conversion is non-trivial.
    return [
        torch.rand(_BLOCK_ROWS, _BLOCK_COLS).round()
        for _ in range(_N_BLOCKS)
    ]


def _row_slice(i):
    """Return the row range occupied by diagonal block ``i``."""
    return slice(i * _BLOCK_ROWS, (i + 1) * _BLOCK_ROWS)


def _col_slice(i):
    """Return the column range occupied by diagonal block ``i``."""
    return slice(i * _BLOCK_COLS, (i + 1) * _BLOCK_COLS)


def test_sparse_diag_cat_01():
    """The densified result equals the hand-built block-diagonal matrix."""
    matrices = _random_blocks()
    ground_truth = torch.zeros(_N_BLOCKS * _BLOCK_ROWS,
                               _N_BLOCKS * _BLOCK_COLS)
    for i, m in enumerate(matrices):
        ground_truth[_row_slice(i), _col_slice(i)] = m

    res = _sparse_diag_cat([m.to_sparse() for m in matrices])
    res = res.to_dense()

    # Check the shape explicitly: ``==`` broadcasts, so a wrongly shaped
    # result could otherwise still compare equal.
    assert res.shape == ground_truth.shape
    # Entries are only copied, never computed, so exact equality is valid.
    assert torch.all(res == ground_truth)


def test_sparse_diag_cat_02():
    """A sparse-dense product with the result matches per-block dense products."""
    x = _random_blocks()
    a = _sparse_diag_cat([m.to_sparse() for m in x])
    b = torch.rand(_N_BLOCKS * _BLOCK_COLS, 64)
    res = torch.sparse.mm(a, b)

    # Row block i of (a @ b) only involves column block i of a, which
    # multiplies row block i of b.
    ground_truth = torch.zeros(_N_BLOCKS * _BLOCK_ROWS, 64)
    for i, m in enumerate(x):
        ground_truth[_row_slice(i)] = torch.mm(m, b[_col_slice(i)])

    assert res.shape == ground_truth.shape
    # The sparse and dense matmul kernels may accumulate sums in a different
    # order, so compare with a floating-point tolerance instead of exact ==.
    assert torch.allclose(res, ground_truth)