IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Sfoglia il codice sorgente

Add test_sparse_diag_cat_02().

master
Stanislaw Adaszewski 4 anni fa
parent
commit
5e2818fb8d
1 ha cambiato i file con 19 aggiunte e 0 eliminazioni
  1. +19
    -0
      tests/icosagon/test_fastconv.py

+ 19
- 0
tests/icosagon/test_fastconv.py Vedi File

@@ -15,3 +15,22 @@ def test_sparse_diag_cat_01():
res = _sparse_diag_cat([ m.to_sparse() for m in matrices ])
res = res.to_dense()
assert torch.all(res == ground_truth)
def test_sparse_diag_cat_02():
x = [ torch.rand(5, 10).round() for _ in range(7) ]
a = [ m.to_sparse() for m in x ]
a = _sparse_diag_cat(a)
b = torch.rand(70, 64)
res = torch.sparse.mm(a, b)
ground_truth = torch.zeros(35, 64)
ground_truth[0:5, :] = torch.mm(x[0], b[0:10])
ground_truth[5:10, :] = torch.mm(x[1], b[10:20])
ground_truth[10:15, :] = torch.mm(x[2], b[20:30])
ground_truth[15:20, :] = torch.mm(x[3], b[30:40])
ground_truth[20:25, :] = torch.mm(x[4], b[40:50])
ground_truth[25:30, :] = torch.mm(x[5], b[50:60])
ground_truth[30:35, :] = torch.mm(x[6], b[60:70])
assert torch.all(res == ground_truth)

Loading…
Annulla
Salva