IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
3.7KB

import time

import pytest
import torch

from triacontagon.util import (
    _clear_adjacency_matrix_except_rows,
    _sparse_diag_cat,
    _equal,
)
  7. def test_clear_adjacency_matrix_except_rows_01():
  8. adj_mat = torch.tensor([
  9. [0, 0, 1, 0, 0],
  10. [0, 0, 0, 1, 1],
  11. [1, 0, 1, 0, 0],
  12. [1, 1, 0, 0, 0]
  13. ], dtype=torch.uint8).to_sparse()
  14. adj_mat = _sparse_diag_cat([ adj_mat, adj_mat ])
  15. res = _clear_adjacency_matrix_except_rows(adj_mat,
  16. torch.tensor([1, 3]), 4, 2)
  17. res = res.to_dense()
  18. truth = torch.tensor([
  19. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  20. [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
  21. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  22. [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
  23. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  24. [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
  25. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  26. [0, 0, 0, 0, 0, 1, 1, 0, 0, 0]
  27. ], dtype=torch.uint8)
  28. print('res:', res)
  29. assert torch.all(res == truth)
  30. def test_clear_adjacency_matrix_except_rows_02():
  31. adj_mat = torch.rand(6, 10).round().to(torch.uint8)
  32. t = time.time()
  33. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
  34. print('_sparse_diag_cat() took:', time.time() - t)
  35. t = time.time()
  36. res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
  37. 6, 130)
  38. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  39. adj_mat[0] = adj_mat[2] = adj_mat[4] = \
  40. torch.zeros(10)
  41. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
  42. assert _equal(res, truth).all()
  43. def test_clear_adjacency_matrix_except_rows_03():
  44. adj_mat = torch.rand(6, 10).round().to(torch.uint8)
  45. t = time.time()
  46. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  47. print('_sparse_diag_cat() took:', time.time() - t)
  48. t = time.time()
  49. res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
  50. 6, 1300)
  51. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  52. adj_mat[0] = adj_mat[2] = adj_mat[4] = \
  53. torch.zeros(10)
  54. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  55. assert _equal(res, truth).all()
  56. def test_clear_adjacency_matrix_except_rows_04():
  57. adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8)
  58. print('adj_mat.to_sparse():', adj_mat.to_sparse())
  59. t = time.time()
  60. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  61. print('_sparse_diag_cat() took:', time.time() - t)
  62. t = time.time()
  63. res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
  64. 2000, 1300)
  65. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  66. adj_mat[0] = adj_mat[2] = adj_mat[4] = \
  67. torch.zeros(2000)
  68. adj_mat[6:] = torch.zeros(2000)
  69. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  70. assert _equal(res, truth).all()
  71. def test_clear_adjacency_matrix_except_rows_05():
  72. if torch.cuda.device_count() == 0:
  73. pytest.skip('Test requires CUDA')
  74. device = torch.device('cuda:0')
  75. adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8).to(device)
  76. print('adj_mat.to_sparse():', adj_mat.to_sparse())
  77. t = time.time()
  78. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  79. print('_sparse_diag_cat() took:', time.time() - t)
  80. rows = torch.tensor(list(range(512)), device=device)
  81. t = time.time()
  82. res = _clear_adjacency_matrix_except_rows(res, rows,
  83. 2000, 1300)
  84. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  85. adj_mat[512:] = torch.zeros(2000)
  86. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  87. assert _equal(res, truth).all()