IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

131 lines
3.6KB

  1. from triacontagon.data import Data
  2. from triacontagon.sampling import fixed_unigram_candidate_sampler, \
  3. get_true_classes, \
  4. negative_sample_adj_mat, \
  5. negative_sample_data, \
  6. get_edges_and_degrees
  7. import triacontagon.sampling
  8. from triacontagon.decode import dedicom_decoder
  9. import torch
  10. import time
  11. import pytest
  def test_fixed_unigram_candidate_sampler_01():
      """Smoke-run the sampler on a small hand-written input.

      Only rows 2 and 3 carry a non-negative true class and a non-zero
      repeat count; every other row holds -1 and a repeat count of 0.
      The sampler output is printed, not asserted.
      """
      known_classes = torch.tensor([[-1], [-1], [3], [2], [-1]])
      repeats_per_row = torch.tensor([0, 0, 1, 1, 0])
      class_weights = torch.tensor([0., 0., 1., 1., 0.], dtype=torch.float64)

      sampled = fixed_unigram_candidate_sampler(
          known_classes, repeats_per_row, class_weights, 0.75)
      print('res:', sampled)
  def test_fixed_unigram_candidate_sampler_02():
      """Run the sampler on inputs derived from one 5x4 sparse matrix.

      True classes, row counts and degrees are all computed from the same
      coalesced sparse matrix. Each intermediate value and the final
      sampler output are printed; nothing is asserted.
      """
      dense = torch.tensor([
          [0, 1, 0, 1],
          [0, 0, 0, 1],
          [0, 1, 0, 0],
          [1, 0, 0, 0],
          [0, 0, 1, 1],
      ], dtype=torch.float32)
      # NOTE: the matrix is used as-is; an earlier variant transposed it
      # (dense.transpose(0, 1)) before converting to sparse.
      sparse = dense.to_sparse().coalesce()

      true_classes, row_count = get_true_classes(sparse)
      print('true_classes:', true_classes)
      print('row_count:', row_count)

      # Only the degrees are needed here; the edge list is discarded.
      _, degrees = get_edges_and_degrees(sparse)
      print('degrees:', degrees)

      sampled = fixed_unigram_candidate_sampler(
          true_classes, row_count, degrees, 0.75)
      print('res:', sampled)
  def test_get_true_classes_01():
      """Check get_true_classes against a hand-computed 5x5 example.

      Each returned row is repeated by its row count, and the expanded
      result must match the expected table element-wise. The expected
      table holds -1 in the second column for rows that have a single
      non-zero entry in the input matrix.
      """
      sparse_adj = torch.tensor([
          [0, 1, 0, 1, 0],
          [0, 0, 0, 0, 1],
          [1, 1, 0, 0, 0],
          [0, 0, 1, 0, 1],
          [0, 1, 0, 0, 0],
      ], dtype=torch.float).to_sparse()

      expected = torch.tensor([
          [1, 3],
          [1, 3],
          [4, -1],
          [0, 1],
          [0, 1],
          [2, 4],
          [2, 4],
          [1, -1],
      ])

      classes, counts = get_true_classes(sparse_adj)
      print('true_classes:', classes)

      expanded = torch.repeat_interleave(classes, counts, dim=0)
      assert torch.all(expanded == expected)
  def test_get_true_classes_02():
      """Time get_true_classes on a random 2000x2000 sparse 0/1 matrix.

      The elapsed time and the shape of the returned true-classes tensor
      are printed; nothing is asserted.
      """
      adj_mat = torch.rand(2000, 2000).round().to_sparse()

      # perf_counter() is monotonic and high-resolution, so the measured
      # interval cannot be skewed by system clock adjustments the way a
      # time.time() difference can.
      started = time.perf_counter()
      true_classes, _ = get_true_classes(adj_mat)
      print('Elapsed:', time.perf_counter() - started)
      print('true_classes.shape:', true_classes.shape)
  def test_negative_sample_adj_mat_01():
      """Print a 5x5 adjacency matrix and its negative sample.

      The matrix is converted to sparse form before sampling and the
      sampled result is densified for printing. Nothing is asserted.
      """
      dense_adj = torch.tensor([
          [0, 1, 0, 1, 0],
          [0, 0, 0, 0, 1],
          [1, 1, 0, 0, 0],
          [0, 0, 1, 0, 1],
          [0, 1, 0, 0, 0],
      ])
      print('adj_mat:', dense_adj)

      negatives = negative_sample_adj_mat(dense_adj.to_sparse())
      print('adj_mat_neg:', negatives.to_dense())
  def test_negative_sample_data_01():
      """Negatively sample a Data object with one vertex and one edge type.

      The test only verifies that negative_sample_data runs on the
      constructed object without raising; the result is not inspected.
      """
      gene_gene = torch.tensor([
          [0, 1, 0, 1, 0],
          [0, 0, 0, 0, 1],
          [1, 1, 0, 0, 0],
          [0, 0, 1, 0, 1],
          [0, 1, 0, 0, 0],
      ], dtype=torch.float).to_sparse()

      data = Data()
      data.add_vertex_type('Gene', 5)
      data.add_edge_type('Gene-Gene', 0, 0, [gene_gene], dedicom_decoder)

      _ = negative_sample_data(data)
  def test_fixed_unigram_candidate_sampler_new_01():
      """Smoke-run fixed_unigram_candidate_sampler_new when it is available.

      The test is skipped if triacontagon.sampling does not expose the
      function. Otherwise the function is called on inputs derived from a
      random sparse 10x10 0/1 matrix; the result is not inspected.
      """
      # hasattr() is the supported way to probe for an attribute; dir() is
      # an interactive convenience whose output is not guaranteed stable.
      if not hasattr(triacontagon.sampling,
                     'fixed_unigram_candidate_sampler_new'):
          pytest.skip('fixed_unigram_candidate_sampler_new not found')

      x = (torch.rand((10, 10)) < .05).to(torch.float32).to_sparse()
      true_classes, row_count = get_true_classes(x)
      # Only the degrees are needed; the edge list is discarded.
      _, degrees = get_edges_and_degrees(x)

      triacontagon.sampling.fixed_unigram_candidate_sampler_new(
          true_classes, row_count, degrees, 0.75)