If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

37 líneas
1.3KB

  1. from icosagon.fastconv import _sparse_diag_cat
  2. import torch
  3. def test_sparse_diag_cat_01():
  4. matrices = [ torch.rand(5, 10).round() for _ in range(7) ]
  5. ground_truth = torch.zeros(35, 70)
  6. ground_truth[0:5, 0:10] = matrices[0]
  7. ground_truth[5:10, 10:20] = matrices[1]
  8. ground_truth[10:15, 20:30] = matrices[2]
  9. ground_truth[15:20, 30:40] = matrices[3]
  10. ground_truth[20:25, 40:50] = matrices[4]
  11. ground_truth[25:30, 50:60] = matrices[5]
  12. ground_truth[30:35, 60:70] = matrices[6]
  13. res = _sparse_diag_cat([ m.to_sparse() for m in matrices ])
  14. res = res.to_dense()
  15. assert torch.all(res == ground_truth)
  16. def test_sparse_diag_cat_02():
  17. x = [ torch.rand(5, 10).round() for _ in range(7) ]
  18. a = [ m.to_sparse() for m in x ]
  19. a = _sparse_diag_cat(a)
  20. b = torch.rand(70, 64)
  21. res = torch.sparse.mm(a, b)
  22. ground_truth = torch.zeros(35, 64)
  23. ground_truth[0:5, :] = torch.mm(x[0], b[0:10])
  24. ground_truth[5:10, :] = torch.mm(x[1], b[10:20])
  25. ground_truth[10:15, :] = torch.mm(x[2], b[20:30])
  26. ground_truth[15:20, :] = torch.mm(x[3], b[30:40])
  27. ground_truth[20:25, :] = torch.mm(x[4], b[40:50])
  28. ground_truth[25:30, :] = torch.mm(x[5], b[50:60])
  29. ground_truth[30:35, :] = torch.mm(x[6], b[60:70])
  30. assert torch.all(res == ground_truth)