IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

163 lines
4.7KB

  1. from triacontagon.util import \
  2. _clear_adjacency_matrix_except_rows, \
  3. _sparse_diag_cat, \
  4. _equal, \
  5. _per_layer_required_vertices
  6. from triacontagon.model import TrainingBatch
  7. from triacontagon.decode import dedicom_decoder
  8. from triacontagon.data import Data
  9. import torch
  10. import time
  11. def test_clear_adjacency_matrix_except_rows_01():
  12. adj_mat = torch.tensor([
  13. [0, 0, 1, 0, 0],
  14. [0, 0, 0, 1, 1],
  15. [1, 0, 1, 0, 0],
  16. [1, 1, 0, 0, 0]
  17. ], dtype=torch.uint8).to_sparse()
  18. adj_mat = _sparse_diag_cat([ adj_mat, adj_mat ])
  19. res = _clear_adjacency_matrix_except_rows(adj_mat,
  20. torch.tensor([1, 3]), 4, 2)
  21. res = res.to_dense()
  22. truth = torch.tensor([
  23. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  24. [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
  25. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  26. [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
  27. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  28. [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
  29. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  30. [0, 0, 0, 0, 0, 1, 1, 0, 0, 0]
  31. ], dtype=torch.uint8)
  32. print('res:', res)
  33. assert torch.all(res == truth)
  34. def test_clear_adjacency_matrix_except_rows_02():
  35. adj_mat = torch.rand(6, 10).round().to(torch.uint8)
  36. t = time.time()
  37. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
  38. print('_sparse_diag_cat() took:', time.time() - t)
  39. t = time.time()
  40. res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
  41. 6, 130)
  42. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  43. adj_mat[0] = adj_mat[2] = adj_mat[4] = \
  44. torch.zeros(10)
  45. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
  46. assert _equal(res, truth).all()
  47. def test_clear_adjacency_matrix_except_rows_03():
  48. adj_mat = torch.rand(6, 10).round().to(torch.uint8)
  49. t = time.time()
  50. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  51. print('_sparse_diag_cat() took:', time.time() - t)
  52. t = time.time()
  53. res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
  54. 6, 1300)
  55. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  56. adj_mat[0] = adj_mat[2] = adj_mat[4] = \
  57. torch.zeros(10)
  58. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  59. assert _equal(res, truth).all()
  60. def test_clear_adjacency_matrix_except_rows_04():
  61. adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8)
  62. print('adj_mat.to_sparse():', adj_mat.to_sparse())
  63. t = time.time()
  64. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  65. print('_sparse_diag_cat() took:', time.time() - t)
  66. t = time.time()
  67. res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
  68. 2000, 1300)
  69. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  70. adj_mat[0] = adj_mat[2] = adj_mat[4] = \
  71. torch.zeros(2000)
  72. adj_mat[6:] = torch.zeros(2000)
  73. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  74. assert _equal(res, truth).all()
  75. def test_clear_adjacency_matrix_except_rows_05():
  76. if torch.cuda.device_count() == 0:
  77. pytest.skip('Test requires CUDA')
  78. device = torch.device('cuda:0')
  79. adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8).to(device)
  80. print('adj_mat.to_sparse():', adj_mat.to_sparse())
  81. t = time.time()
  82. res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  83. print('_sparse_diag_cat() took:', time.time() - t)
  84. rows = torch.tensor(list(range(512)), device=device)
  85. t = time.time()
  86. res = _clear_adjacency_matrix_except_rows(res, rows,
  87. 2000, 1300)
  88. print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
  89. adj_mat[512:] = torch.zeros(2000)
  90. truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
  91. assert _equal(res, truth).all()
  92. def test_per_layer_required_vertices_01():
  93. d = Data()
  94. d.add_vertex_type('Gene', 4)
  95. d.add_vertex_type('Drug', 5)
  96. d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
  97. [1, 0, 0, 1],
  98. [0, 1, 1, 0],
  99. [0, 0, 1, 0],
  100. [0, 1, 0, 1]
  101. ]).to_sparse() ], dedicom_decoder)
  102. d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
  103. [0, 1, 0, 0, 1],
  104. [0, 0, 1, 0, 0],
  105. [1, 0, 0, 0, 1],
  106. [0, 0, 1, 1, 0]
  107. ]).to_sparse() ], dedicom_decoder)
  108. d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
  109. [1, 0, 0, 0, 0],
  110. [0, 1, 0, 0, 0],
  111. [0, 0, 1, 0, 0],
  112. [0, 0, 0, 1, 0],
  113. [0, 0, 0, 0, 1]
  114. ]).to_sparse() ], dedicom_decoder)
  115. batch = TrainingBatch(0, 1, 0, torch.tensor([
  116. [0, 1]
  117. ]))
  118. res = _per_layer_required_vertices(d, batch, 5)
  119. print('res:', res)