|
|
@@ -1,4 +1,6 @@ |
|
|
|
from icosagon.fastconv import _sparse_diag_cat
|
|
|
|
from icosagon.fastconv import _sparse_diag_cat, \
|
|
|
|
_cat
|
|
|
|
from icosagon.data import _equal
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
@@ -34,3 +36,25 @@ def test_sparse_diag_cat_02(): |
|
|
|
ground_truth[30:35, :] = torch.mm(x[6], b[60:70])
|
|
|
|
|
|
|
|
assert torch.all(res == ground_truth)
|
|
|
|
|
|
|
|
|
|
|
|
def test_cat_01():
|
|
|
|
matrices = [ torch.rand(5, 10) for _ in range(7) ]
|
|
|
|
res = _cat(matrices)
|
|
|
|
assert res.shape == (35, 10)
|
|
|
|
assert not res.is_sparse
|
|
|
|
ground_truth = torch.zeros(35, 10)
|
|
|
|
for i in range(7):
|
|
|
|
ground_truth[i*5:(i+1)*5, :] = matrices[i]
|
|
|
|
assert torch.all(res == ground_truth)
|
|
|
|
|
|
|
|
|
|
|
|
def test_cat_02():
|
|
|
|
matrices = [ torch.rand(5, 10) for _ in range(7) ]
|
|
|
|
ground_truth = torch.zeros(35, 10)
|
|
|
|
for i in range(7):
|
|
|
|
ground_truth[i*5:(i+1)*5, :] = matrices[i]
|
|
|
|
res = _cat([ m.to_sparse() for m in matrices ])
|
|
|
|
assert res.shape == (35, 10)
|
|
|
|
assert res.is_sparse
|
|
|
|
assert torch.all(res.to_dense() == ground_truth)
|