IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

116 lines
3.0KB

  1. from triacontagon.data import Data
  2. from triacontagon.sampling import fixed_unigram_candidate_sampler, \
  3. get_true_classes, \
  4. negative_sample_adj_mat, \
  5. negative_sample_data, \
  6. get_edges_and_degrees
  7. from triacontagon.decode import dedicom_decoder
  8. import torch
  9. import time
  def test_fixed_unigram_candidate_sampler_01():
      """Smoke-test fixed_unigram_candidate_sampler on a tiny hand-built input.

      Five rows of true classes are given; only rows 2 and 3 (true classes 3
      and 2) have a non-zero repeat count, and the unigram weights are
      non-zero only for classes 2 and 3. The rows holding -1 have zero
      repeats. The sampler output is printed; nothing is asserted.
      """
      unigram_weights = torch.tensor([0., 0., 1., 1., 0.], dtype=torch.float64)
      repeat_counts = torch.tensor([0, 0, 1, 1, 0])
      # One true class per row, as a (5, 1) integer tensor.
      positives = torch.tensor([-1, -1, 3, 2, -1]).view(5, 1)
      samples = fixed_unigram_candidate_sampler(
          positives, repeat_counts, unigram_weights, 0.75)
      print('res:', samples)
  def test_fixed_unigram_candidate_sampler_02():
      """Run the sampler on inputs derived from a small sparse 5x4 matrix.

      The true classes and per-row counts come from get_true_classes, and the
      unigram weights are the degrees returned by get_edges_and_degrees for
      the same matrix (distortion 0.75). Intermediate values and the sampler
      output are printed; nothing is asserted.
      """
      adj = torch.tensor([
          [0, 1, 0, 1],
          [0, 0, 0, 1],
          [0, 1, 0, 0],
          [1, 0, 0, 0],
          [0, 0, 1, 1]
      ], dtype=torch.float32).to_sparse().coalesce()
      # NOTE(review): an earlier variant of this test built the sparse matrix
      # from the transpose (dense.transpose(0, 1)); only the untransposed
      # orientation is exercised here.
      true_classes, row_count = get_true_classes(adj)
      print('true_classes:', true_classes)
      print('row_count:', row_count)
      # Only the degrees are needed as unigram weights; the edge list is unused.
      _, degrees = get_edges_and_degrees(adj)
      res = fixed_unigram_candidate_sampler(true_classes, row_count,
                                            degrees, 0.75)
      print('res:', res)
  def test_get_true_classes_01():
      """Check get_true_classes on a hand-built 5x5 adjacency matrix.

      Expanding each returned row by its count (torch.repeat_interleave along
      dim 0) must give one row per non-zero entry of the matrix. Each such row
      lists the column indices of that entry's matrix row, and rows with fewer
      non-zeros are padded with -1.
      """
      dense = torch.tensor([
          [0, 1, 0, 1, 0],
          [0, 0, 0, 0, 1],
          [1, 1, 0, 0, 0],
          [0, 0, 1, 0, 1],
          [0, 1, 0, 0, 0]
      ], dtype=torch.float)
      sparse_adj = dense.to_sparse()
      classes, counts = get_true_classes(sparse_adj)
      print('true_classes:', classes)
      expected = torch.tensor([
          [1, 3], [1, 3],   # matrix row 0: two non-zeros
          [4, -1],          # matrix row 1: one non-zero, padded
          [0, 1], [0, 1],   # matrix row 2
          [2, 4], [2, 4],   # matrix row 3
          [1, -1],          # matrix row 4: one non-zero, padded
      ])
      expanded = torch.repeat_interleave(classes, counts, dim=0)
      assert torch.all(expanded == expected)
  def test_get_true_classes_02():
      """Time get_true_classes on a large random 0/1 sparse matrix.

      Builds a 2000x2000 matrix by rounding uniform [0, 1) samples and then
      converts it to sparse. It reports the elapsed time of get_true_classes
      and the shape of the returned true-class tensor. Nothing is asserted;
      this is a timing smoke test.
      """
      # Fixed-seed generator so that any failure or slowdown can be
      # reproduced with exactly the same input matrix.
      gen = torch.Generator().manual_seed(0)
      adj_mat = torch.rand(2000, 2000, generator=gen).round().to_sparse()
      # perf_counter is monotonic and high-resolution; time.time() can jump
      # if the system clock is adjusted mid-measurement.
      start = time.perf_counter()
      true_classes, _ = get_true_classes(adj_mat)
      print('Elapsed:', time.perf_counter() - start)
      print('true_classes.shape:', true_classes.shape)
  def test_negative_sample_adj_mat_01():
      """Smoke-test negative_sample_adj_mat on a small 5x5 0/1 matrix.

      The input matrix is created without an explicit dtype, so it keeps
      torch.tensor's default integer dtype. It is printed, converted to
      sparse, and passed to the sampler, and the dense form of the result is
      printed. Nothing is asserted.
      """
      dense_adj = torch.tensor([
          [0, 1, 0, 1, 0],
          [0, 0, 0, 0, 1],
          [1, 1, 0, 0, 0],
          [0, 0, 1, 0, 1],
          [0, 1, 0, 0, 0]
      ])
      print('adj_mat:', dense_adj)
      sparse_adj = dense_adj.to_sparse()
      negatives = negative_sample_adj_mat(sparse_adj)
      print('adj_mat_neg:', negatives.to_dense())
  77. def test_negative_sample_data_01():
  78. d = Data()
  79. d.add_vertex_type('Gene', 5)
  80. d.add_edge_type('Gene-Gene', 0, 0, [
  81. torch.tensor([
  82. [0, 1, 0, 1, 0],
  83. [0, 0, 0, 0, 1],
  84. [1, 1, 0, 0, 0],
  85. [0, 0, 1, 0, 1],
  86. [0, 1, 0, 0, 0]
  87. ], dtype=torch.float).to_sparse()
  88. ], dedicom_decoder)
  89. d_neg = negative_sample_data(d)