IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add test_cat_01() and test_cat_02().

master
Stanislaw Adaszewski 3 years ago
parent
commit
271fba0004
2 changed files with 27 additions and 3 deletions
  1. +2
    -2
      src/icosagon/fastconv.py
  2. +25
    -1
      tests/icosagon/test_fastconv.py

+ 2
- 2
src/icosagon/fastconv.py View File

@@ -46,7 +46,7 @@ def _cat(matrices: List[torch.Tensor]):
if n != 0 and n != len(matrices):
raise ValueError('All matrices must have the same layout (dense or sparse)')
if not all(a.shape[1:] == matrices[0].shape[1:]):
if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
raise ValueError('All matrices must have the same dimensions apart from dimension 0')
if not matrices[0].is_sparse:
@@ -69,7 +69,7 @@ def _cat(matrices: List[torch.Tensor]):
indices = torch.cat(indices).transpose(0, 1)
values = torch.cat(values)
res = _sparse_coo_tensor(indices, values)
res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
return res


+ 25
- 1
tests/icosagon/test_fastconv.py View File

@@ -1,4 +1,6 @@
from icosagon.fastconv import _sparse_diag_cat
from icosagon.fastconv import _sparse_diag_cat, \
_cat
from icosagon.data import _equal
import torch
@@ -34,3 +36,25 @@ def test_sparse_diag_cat_02():
ground_truth[30:35, :] = torch.mm(x[6], b[60:70])
assert torch.all(res == ground_truth)
def test_cat_01():
matrices = [ torch.rand(5, 10) for _ in range(7) ]
res = _cat(matrices)
assert res.shape == (35, 10)
assert not res.is_sparse
ground_truth = torch.zeros(35, 10)
for i in range(7):
ground_truth[i*5:(i+1)*5, :] = matrices[i]
assert torch.all(res == ground_truth)
def test_cat_02():
matrices = [ torch.rand(5, 10) for _ in range(7) ]
ground_truth = torch.zeros(35, 10)
for i in range(7):
ground_truth[i*5:(i+1)*5, :] = matrices[i]
res = _cat([ m.to_sparse() for m in matrices ])
assert res.shape == (35, 10)
assert res.is_sparse
assert torch.all(res.to_dense() == ground_truth)

Loading…
Cancel
Save