|
- from icosagon.fastconv import _sparse_diag_cat
- import torch
-
-
def test_sparse_diag_cat_01():
    """Check that _sparse_diag_cat lays sparse blocks out along the diagonal.

    Seven random 0/1-valued 5x10 matrices are converted to sparse tensors
    with ``to_sparse()`` and passed to ``_sparse_diag_cat``. The densified
    result must equal a dense reference in which block ``i`` occupies rows
    ``5*i:5*(i+1)`` and columns ``10*i:10*(i+1)``, with zeros elsewhere.
    """
    # A local generator makes the inputs reproducible (a failure can be
    # re-run) without reseeding the global RNG that other tests may use.
    gen = torch.Generator().manual_seed(0)
    matrices = [torch.rand(5, 10, generator=gen).round() for _ in range(7)]

    # Build the dense reference by walking down the diagonal; offsets are
    # derived from each block's shape rather than hard-coded per block.
    n_rows = sum(m.shape[0] for m in matrices)
    n_cols = sum(m.shape[1] for m in matrices)
    ground_truth = torch.zeros(n_rows, n_cols)
    row, col = 0, 0
    for m in matrices:
        height, width = m.shape
        ground_truth[row:row + height, col:col + width] = m
        row += height
        col += width

    res = _sparse_diag_cat([m.to_sparse() for m in matrices])
    res = res.to_dense()
    # `==` broadcasts, so check the shape explicitly: a wrongly shaped
    # result must not pass just because it broadcasts against the reference.
    assert res.shape == ground_truth.shape
    assert torch.all(res == ground_truth)
|