IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
4.1KB

  1. from triacontagon.util import \
  2. _clear_adjacency_matrix_except_rows, \
  3. _sparse_diag_cat, \
  4. _equal, \
  5. common_one_hot_encoding
  6. from triacontagon.model import TrainingBatch
  7. from triacontagon.decode import dedicom_decoder
  8. from triacontagon.data import Data
  9. import torch
  10. import time
  11. def test_clear_adjacency_matrix_except_rows_01():
  12. adj_mat = torch.tensor([
  13. [0, 0, 1, 0, 0],
  14. [0, 0, 0, 1, 1],
  15. [1, 0, 1, 0, 0],
  16. [1, 1, 0, 0, 0]
  17. ], dtype=torch.uint8).to_sparse()
  18. adj_mat = _sparse_diag_cat([ adj_mat, adj_mat ])
  19. res = _clear_adjacency_matrix_except_rows(adj_mat,
  20. torch.tensor([1, 3]), 4, 2)
  21. res = res.to_dense()
  22. truth = torch.tensor([
  23. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  24. [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
  25. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  26. [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
  27. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  28. [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
  29. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  30. [0, 0, 0, 0, 0, 1, 1, 0, 0, 0]
  31. ], dtype=torch.uint8)
  32. print('res:', res)
  33. assert torch.all(res == truth)
  34. def test_clear_adjacency_matrix_except_rows_02():
  35. adj_mat = torch.rand(6, 10).round().to(torch.uint8)
  36. t = time.time()
  37. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
  38. print('_sparse_diag_cat() took:', time.time() - t)
  39. t = time.time()
  40. res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
  41. 6, 130)
  42. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  43. adj_mat[0] = adj_mat[2] = adj_mat[4] = \
  44. torch.zeros(10)
  45. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
  46. assert _equal(res, truth).all()
  47. def test_clear_adjacency_matrix_except_rows_03():
  48. adj_mat = torch.rand(6, 10).round().to(torch.uint8)
  49. t = time.time()
  50. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  51. print('_sparse_diag_cat() took:', time.time() - t)
  52. t = time.time()
  53. res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
  54. 6, 1300)
  55. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  56. adj_mat[0] = adj_mat[2] = adj_mat[4] = \
  57. torch.zeros(10)
  58. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  59. assert _equal(res, truth).all()
  60. def test_clear_adjacency_matrix_except_rows_04():
  61. adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8)
  62. print('adj_mat.to_sparse():', adj_mat.to_sparse())
  63. t = time.time()
  64. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  65. print('_sparse_diag_cat() took:', time.time() - t)
  66. t = time.time()
  67. res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
  68. 2000, 1300)
  69. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  70. adj_mat[0] = adj_mat[2] = adj_mat[4] = \
  71. torch.zeros(2000)
  72. adj_mat[6:] = torch.zeros(2000)
  73. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  74. assert _equal(res, truth).all()
  75. def test_clear_adjacency_matrix_except_rows_05():
  76. if torch.cuda.device_count() == 0:
  77. pytest.skip('Test requires CUDA')
  78. device = torch.device('cuda:0')
  79. adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8).to(device)
  80. print('adj_mat.to_sparse():', adj_mat.to_sparse())
  81. t = time.time()
  82. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  83. print('_sparse_diag_cat() took:', time.time() - t)
  84. rows = torch.tensor(list(range(512)), device=device)
  85. t = time.time()
  86. res = _clear_adjacency_matrix_except_rows(res, rows,
  87. 2000, 1300)
  88. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  89. adj_mat[512:] = torch.zeros(2000)
  90. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  91. assert _equal(res, truth).all()
  92. def test_common_one_hot_encoding_01():
  93. in_repr = common_one_hot_encoding([2000, 200])
  94. ref = torch.eye(2200)
  95. assert torch.all(in_repr[0].to_dense() == ref[:2000, :])
  96. assert torch.all(in_repr[1].to_dense() == ref[2000:, :])