IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

317 lines
8.4KB

  1. from triacontagon.batch import _same_data_org, \
  2. DualBatcher, \
  3. Batcher
  4. from triacontagon.data import Data
  5. from triacontagon.decode import dedicom_decoder
  6. import torch
  7. def test_same_data_org_01():
  8. data = Data()
  9. assert _same_data_org(data, data)
  10. data.add_vertex_type('Foo', 10)
  11. assert _same_data_org(data, data)
  12. data.add_vertex_type('Bar', 10)
  13. assert _same_data_org(data, data)
  14. data_1 = Data()
  15. assert not _same_data_org(data, data_1)
  16. data_1.add_vertex_type('Foo', 10)
  17. assert not _same_data_org(data, data_1)
  18. data_1.add_vertex_type('Bar', 10)
  19. assert _same_data_org(data, data_1)
  20. def test_same_data_org_02():
  21. data = Data()
  22. data.add_vertex_type('Foo', 4)
  23. data.add_edge_type('Foo-Foo', 0, 0, [
  24. torch.tensor([
  25. [0, 0, 0, 1],
  26. [1, 0, 0, 0],
  27. [0, 1, 0, 1],
  28. [1, 0, 1, 0]
  29. ]).to_sparse()
  30. ], dedicom_decoder)
  31. assert _same_data_org(data, data)
  32. data_1 = Data()
  33. data_1.add_vertex_type('Foo', 4)
  34. data_1.add_edge_type('Foo-Foo', 0, 0, [
  35. torch.tensor([
  36. [0, 0, 0, 1],
  37. [1, 0, 0, 0],
  38. [0, 1, 0, 1],
  39. [1, 0, 0, 0]
  40. ]).to_sparse()
  41. ], dedicom_decoder)
  42. assert not _same_data_org(data, data_1)
  43. def test_batcher_01():
  44. d = Data()
  45. d.add_vertex_type('Gene', 5)
  46. d.add_edge_type('Gene-Gene', 0, 0, [
  47. torch.tensor([
  48. [0, 1, 0, 1, 0],
  49. [0, 0, 0, 0, 1],
  50. [1, 0, 0, 0, 0],
  51. [0, 0, 1, 0, 0],
  52. [0, 0, 0, 1, 0]
  53. ]).to_sparse()
  54. ], dedicom_decoder)
  55. b = Batcher(d, batch_size=1)
  56. visited = set()
  57. for t in b:
  58. print(t)
  59. k = tuple(t.edges[0].tolist())
  60. visited.add(k)
  61. assert visited == { (0, 1), (0, 3),
  62. (1, 4), (2, 0), (3, 2), (4, 3) }
  63. def test_batcher_02():
  64. d = Data()
  65. d.add_vertex_type('Gene', 5)
  66. d.add_edge_type('Gene-Gene', 0, 0, [
  67. torch.tensor([
  68. [0, 1, 0, 1, 0],
  69. [0, 0, 0, 0, 1],
  70. [1, 0, 0, 0, 0],
  71. [0, 0, 1, 0, 0],
  72. [0, 0, 0, 1, 0]
  73. ]).to_sparse(),
  74. torch.tensor([
  75. [0, 0, 1, 0, 1],
  76. [0, 0, 0, 1, 0],
  77. [0, 0, 0, 0, 1],
  78. [0, 1, 0, 0, 0],
  79. [0, 0, 1, 0, 0]
  80. ]).to_sparse()
  81. ], dedicom_decoder)
  82. b = Batcher(d, batch_size=1)
  83. visited = set()
  84. for t in b:
  85. print(t)
  86. k = (t.relation_type_index,) + \
  87. tuple(t.edges[0].tolist())
  88. visited.add(k)
  89. assert visited == { (0, 0, 1), (0, 0, 3),
  90. (0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3),
  91. (1, 0, 2), (1, 0, 4), (1, 1, 3), (1, 2, 4),
  92. (1, 3, 1), (1, 4, 2) }
  93. def test_batcher_03():
  94. d = Data()
  95. d.add_vertex_type('Gene', 5)
  96. d.add_vertex_type('Drug', 4)
  97. d.add_edge_type('Gene-Gene', 0, 0, [
  98. torch.tensor([
  99. [0, 1, 0, 1, 0],
  100. [0, 0, 0, 0, 1],
  101. [1, 0, 0, 0, 0],
  102. [0, 0, 1, 0, 0],
  103. [0, 0, 0, 1, 0]
  104. ]).to_sparse(),
  105. torch.tensor([
  106. [0, 0, 1, 0, 1],
  107. [0, 0, 0, 1, 0],
  108. [0, 0, 0, 0, 1],
  109. [0, 1, 0, 0, 0],
  110. [0, 0, 1, 0, 0]
  111. ]).to_sparse()
  112. ], dedicom_decoder)
  113. d.add_edge_type('Gene-Drug', 0, 1, [
  114. torch.tensor([
  115. [0, 1, 0, 0],
  116. [1, 0, 0, 1],
  117. [0, 1, 0, 0],
  118. [0, 0, 1, 0],
  119. [0, 1, 1, 0]
  120. ]).to_sparse()
  121. ], dedicom_decoder)
  122. b = Batcher(d, batch_size=1)
  123. visited = set()
  124. for t in b:
  125. print(t)
  126. k = (t.vertex_type_row, t.vertex_type_column,
  127. t.relation_type_index,) + \
  128. tuple(t.edges[0].tolist())
  129. visited.add(k)
  130. assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
  131. (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
  132. (0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
  133. (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
  134. (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
  135. (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
  136. (0, 1, 0, 4, 2) }
  137. def test_batcher_04():
  138. d = Data()
  139. d.add_vertex_type('Gene', 5)
  140. d.add_edge_type('Gene-Gene', 0, 0, [
  141. torch.tensor([
  142. [0, 1, 0, 1, 0],
  143. [0, 0, 0, 0, 1],
  144. [1, 0, 0, 0, 0],
  145. [0, 0, 1, 0, 0],
  146. [0, 0, 0, 1, 0]
  147. ]).to_sparse()
  148. ], dedicom_decoder)
  149. b = Batcher(d, batch_size=3)
  150. visited = set()
  151. for t in b:
  152. print(t)
  153. for e in t.edges:
  154. k = tuple(e.tolist())
  155. visited.add(k)
  156. assert visited == { (0, 1), (0, 3),
  157. (1, 4), (2, 0), (3, 2), (4, 3) }
  158. def test_batcher_05():
  159. d = Data()
  160. d.add_vertex_type('Gene', 5)
  161. d.add_vertex_type('Drug', 4)
  162. d.add_edge_type('Gene-Gene', 0, 0, [
  163. torch.tensor([
  164. [0, 1, 0, 1, 0],
  165. [0, 0, 0, 0, 1],
  166. [1, 0, 0, 0, 0],
  167. [0, 0, 1, 0, 0],
  168. [0, 0, 0, 1, 0]
  169. ]).to_sparse(),
  170. torch.tensor([
  171. [0, 0, 1, 0, 1],
  172. [0, 0, 0, 1, 0],
  173. [0, 0, 0, 0, 1],
  174. [0, 1, 0, 0, 0],
  175. [0, 0, 1, 0, 0]
  176. ]).to_sparse()
  177. ], dedicom_decoder)
  178. d.add_edge_type('Gene-Drug', 0, 1, [
  179. torch.tensor([
  180. [0, 1, 0, 0],
  181. [1, 0, 0, 1],
  182. [0, 1, 0, 0],
  183. [0, 0, 1, 0],
  184. [0, 1, 1, 0]
  185. ]).to_sparse()
  186. ], dedicom_decoder)
  187. b = Batcher(d, batch_size=5)
  188. visited = set()
  189. for t in b:
  190. print(t)
  191. for e in t.edges:
  192. k = (t.vertex_type_row, t.vertex_type_column,
  193. t.relation_type_index,) + \
  194. tuple(e.tolist())
  195. visited.add(k)
  196. assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
  197. (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
  198. (0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
  199. (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
  200. (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
  201. (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
  202. (0, 1, 0, 4, 2) }
  203. def test_dual_batcher_01():
  204. d = Data()
  205. d.add_vertex_type('Gene', 5)
  206. d.add_vertex_type('Drug', 4)
  207. d.add_edge_type('Gene-Gene', 0, 0, [
  208. torch.tensor([
  209. [0, 1, 0, 1, 0],
  210. [0, 0, 0, 0, 1],
  211. [1, 0, 0, 0, 0],
  212. [0, 0, 1, 0, 0],
  213. [0, 0, 0, 1, 0]
  214. ]).to_sparse(),
  215. torch.tensor([
  216. [0, 0, 1, 0, 1],
  217. [0, 0, 0, 1, 0],
  218. [0, 0, 0, 0, 1],
  219. [0, 1, 0, 0, 0],
  220. [0, 0, 1, 0, 0]
  221. ]).to_sparse()
  222. ], dedicom_decoder)
  223. d.add_edge_type('Gene-Drug', 0, 1, [
  224. torch.tensor([
  225. [0, 1, 0, 0],
  226. [1, 0, 0, 1],
  227. [0, 1, 0, 0],
  228. [0, 0, 1, 0],
  229. [0, 1, 1, 0]
  230. ]).to_sparse()
  231. ], dedicom_decoder)
  232. b = DualBatcher(d, d, batch_size=5)
  233. visited_pos = set()
  234. visited_neg = set()
  235. for t_pos, t_neg in b:
  236. assert t_pos.vertex_type_row == t_neg.vertex_type_row
  237. assert t_pos.vertex_type_column == t_neg.vertex_type_column
  238. assert t_pos.relation_type_index == t_neg.relation_type_index
  239. assert len(t_pos.edges) == len(t_neg.edges)
  240. for e in t_pos.edges:
  241. k = (t_pos.vertex_type_row, t_pos.vertex_type_column,
  242. t_pos.relation_type_index,) + \
  243. tuple(e.tolist())
  244. visited_pos.add(k)
  245. for e in t_neg.edges:
  246. k = (t_neg.vertex_type_row, t_neg.vertex_type_column,
  247. t_neg.relation_type_index,) + \
  248. tuple(e.tolist())
  249. visited_neg.add(k)
  250. expected = { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
  251. (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
  252. (0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
  253. (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
  254. (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
  255. (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
  256. (0, 1, 0, 4, 2) }
  257. assert visited_pos == expected
  258. assert visited_neg == expected