IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

72 lines
1.7KB

  1. from triacontagon.data import Data
  2. from triacontagon.sampling import get_true_classes, \
  3. negative_sample_adj_mat, \
  4. negative_sample_data
  5. from triacontagon.decode import dedicom_decoder
  6. import torch
  7. import time
  8. def test_get_true_classes_01():
  9. adj_mat = torch.tensor([
  10. [0, 1, 0, 1, 0],
  11. [0, 0, 0, 0, 1],
  12. [1, 1, 0, 0, 0],
  13. [0, 0, 1, 0, 1],
  14. [0, 1, 0, 0, 0]
  15. ], dtype=torch.float).to_sparse()
  16. true_classes = get_true_classes(adj_mat)
  17. print('true_classes:', true_classes)
  18. assert torch.all(true_classes == torch.tensor([
  19. [1, 3],
  20. [4, -1],
  21. [0, 1],
  22. [2, 4],
  23. [1, -1]
  24. ]))
  25. def test_get_true_classes_02():
  26. adj_mat = torch.rand(2000, 2000).round().to_sparse()
  27. t = time.time()
  28. true_classes = get_true_classes(adj_mat)
  29. print('Elapsed:', time.time() - t)
  30. print('true_classes.shape:', true_classes.shape)
  31. def test_negative_sample_adj_mat_01():
  32. adj_mat = torch.tensor([
  33. [0, 1, 0, 1, 0],
  34. [0, 0, 0, 0, 1],
  35. [1, 1, 0, 0, 0],
  36. [0, 0, 1, 0, 1],
  37. [0, 1, 0, 0, 0]
  38. ])
  39. print('adj_mat:', adj_mat)
  40. adj_mat_neg = negative_sample_adj_mat(adj_mat.to_sparse())
  41. print('adj_mat_neg:', adj_mat_neg.to_dense())
  42. def test_negative_sample_data_01():
  43. d = Data()
  44. d.add_vertex_type('Gene', 5)
  45. d.add_edge_type('Gene-Gene', 0, 0, [
  46. torch.tensor([
  47. [0, 1, 0, 1, 0],
  48. [0, 0, 0, 0, 1],
  49. [1, 1, 0, 0, 0],
  50. [0, 0, 1, 0, 1],
  51. [0, 1, 0, 0, 0]
  52. ], dtype=torch.float).to_sparse()
  53. ], dedicom_decoder)
  54. d_neg = negative_sample_data(d)