IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

120 lines
3.2KB

  1. from triacontagon.batch import Batcher
  2. from triacontagon.data import Data
  3. from triacontagon.decode import dedicom_decoder
  4. import torch
  5. def test_batcher_01():
  6. d = Data()
  7. d.add_vertex_type('Gene', 5)
  8. d.add_edge_type('Gene-Gene', 0, 0, [
  9. torch.tensor([
  10. [0, 1, 0, 1, 0],
  11. [0, 0, 0, 0, 1],
  12. [1, 0, 0, 0, 0],
  13. [0, 0, 1, 0, 0],
  14. [0, 0, 0, 1, 0]
  15. ]).to_sparse()
  16. ], dedicom_decoder)
  17. b = Batcher(d, batch_size=1)
  18. visited = set()
  19. for t in b:
  20. print(t)
  21. k = tuple(t.edges[0].tolist())
  22. visited.add(k)
  23. assert visited == { (0, 1), (0, 3),
  24. (1, 4), (2, 0), (3, 2), (4, 3) }
  25. def test_batcher_02():
  26. d = Data()
  27. d.add_vertex_type('Gene', 5)
  28. d.add_edge_type('Gene-Gene', 0, 0, [
  29. torch.tensor([
  30. [0, 1, 0, 1, 0],
  31. [0, 0, 0, 0, 1],
  32. [1, 0, 0, 0, 0],
  33. [0, 0, 1, 0, 0],
  34. [0, 0, 0, 1, 0]
  35. ]).to_sparse(),
  36. torch.tensor([
  37. [1, 0, 1, 0, 0],
  38. [0, 0, 0, 1, 0],
  39. [0, 0, 0, 0, 1],
  40. [0, 1, 0, 0, 0],
  41. [0, 0, 1, 0, 0]
  42. ]).to_sparse()
  43. ], dedicom_decoder)
  44. b = Batcher(d, batch_size=1)
  45. visited = set()
  46. for t in b:
  47. print(t)
  48. k = (t.relation_type_index,) + \
  49. tuple(t.edges[0].tolist())
  50. visited.add(k)
  51. assert visited == { (0, 0, 1), (0, 0, 3),
  52. (0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3),
  53. (1, 0, 0), (1, 0, 2), (1, 1, 3), (1, 2, 4),
  54. (1, 3, 1), (1, 4, 2) }
  55. def test_batcher_03():
  56. d = Data()
  57. d.add_vertex_type('Gene', 5)
  58. d.add_vertex_type('Drug', 4)
  59. d.add_edge_type('Gene-Gene', 0, 0, [
  60. torch.tensor([
  61. [0, 1, 0, 1, 0],
  62. [0, 0, 0, 0, 1],
  63. [1, 0, 0, 0, 0],
  64. [0, 0, 1, 0, 0],
  65. [0, 0, 0, 1, 0]
  66. ]).to_sparse(),
  67. torch.tensor([
  68. [1, 0, 1, 0, 0],
  69. [0, 0, 0, 1, 0],
  70. [0, 0, 0, 0, 1],
  71. [0, 1, 0, 0, 0],
  72. [0, 0, 1, 0, 0]
  73. ]).to_sparse()
  74. ], dedicom_decoder)
  75. d.add_edge_type('Gene-Drug', 0, 1, [
  76. torch.tensor([
  77. [0, 1, 0, 0],
  78. [1, 0, 0, 1],
  79. [0, 1, 0, 0],
  80. [0, 0, 1, 0],
  81. [0, 1, 1, 0]
  82. ]).to_sparse()
  83. ], dedicom_decoder)
  84. b = Batcher(d, batch_size=1)
  85. visited = set()
  86. for t in b:
  87. print(t)
  88. k = (t.vertex_type_row, t.vertex_type_column,
  89. t.relation_type_index,) + \
  90. tuple(t.edges[0].tolist())
  91. visited.add(k)
  92. assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
  93. (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
  94. (0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
  95. (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
  96. (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
  97. (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
  98. (0, 1, 0, 4, 2) }