IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

317 Zeilen
8.4KB

  1. from triacontagon.batch import _same_data_org, \
  2. DualBatcher, \
  3. Batcher
  4. from triacontagon.data import Data
  5. from triacontagon.decode import dedicom_decoder
  6. import torch
  7. def test_same_data_org_01():
  8. data = Data()
  9. assert _same_data_org(data, data)
  10. data.add_vertex_type('Foo', 10)
  11. assert _same_data_org(data, data)
  12. data.add_vertex_type('Bar', 10)
  13. assert _same_data_org(data, data)
  14. data_1 = Data()
  15. assert not _same_data_org(data, data_1)
  16. data_1.add_vertex_type('Foo', 10)
  17. assert not _same_data_org(data, data_1)
  18. data_1.add_vertex_type('Bar', 10)
  19. assert _same_data_org(data, data_1)
  20. def test_same_data_org_02():
  21. data = Data()
  22. data.add_vertex_type('Foo', 4)
  23. data.add_edge_type('Foo-Foo', 0, 0, [
  24. torch.tensor([
  25. [0, 0, 0, 1],
  26. [1, 0, 0, 0],
  27. [0, 1, 0, 1],
  28. [1, 0, 1, 0]
  29. ]).to_sparse()
  30. ], dedicom_decoder)
  31. assert _same_data_org(data, data)
  32. data_1 = Data()
  33. data_1.add_vertex_type('Foo', 4)
  34. data_1.add_edge_type('Foo-Foo', 0, 0, [
  35. torch.tensor([
  36. [0, 0, 0, 1],
  37. [1, 0, 0, 0],
  38. [0, 1, 0, 1],
  39. [1, 0, 0, 0]
  40. ]).to_sparse()
  41. ], dedicom_decoder)
  42. assert not _same_data_org(data, data_1)
  43. def test_batcher_01():
  44. d = Data()
  45. d.add_vertex_type('Gene', 5)
  46. d.add_edge_type('Gene-Gene', 0, 0, [
  47. torch.tensor([
  48. [0, 1, 0, 1, 0],
  49. [0, 0, 0, 0, 1],
  50. [1, 0, 0, 0, 0],
  51. [0, 0, 1, 0, 0],
  52. [0, 0, 0, 1, 0]
  53. ]).to_sparse()
  54. ], dedicom_decoder)
  55. b = Batcher(d, batch_size=1)
  56. visited = set()
  57. for t in b:
  58. print(t)
  59. k = tuple(t.edges[0].tolist())
  60. visited.add(k)
  61. assert visited == { (0, 1), (0, 3),
  62. (1, 4), (2, 0), (3, 2), (4, 3) }
  63. def test_batcher_02():
  64. d = Data()
  65. d.add_vertex_type('Gene', 5)
  66. d.add_edge_type('Gene-Gene', 0, 0, [
  67. torch.tensor([
  68. [0, 1, 0, 1, 0],
  69. [0, 0, 0, 0, 1],
  70. [1, 0, 0, 0, 0],
  71. [0, 0, 1, 0, 0],
  72. [0, 0, 0, 1, 0]
  73. ]).to_sparse(),
  74. torch.tensor([
  75. [0, 0, 1, 0, 1],
  76. [0, 0, 0, 1, 0],
  77. [0, 0, 0, 0, 1],
  78. [0, 1, 0, 0, 0],
  79. [0, 0, 1, 0, 0]
  80. ]).to_sparse()
  81. ], dedicom_decoder)
  82. b = Batcher(d, batch_size=1)
  83. visited = set()
  84. for t in b:
  85. print(t)
  86. k = (t.relation_type_index,) + \
  87. tuple(t.edges[0].tolist())
  88. visited.add(k)
  89. assert visited == { (0, 0, 1), (0, 0, 3),
  90. (0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3),
  91. (1, 0, 2), (1, 0, 4), (1, 1, 3), (1, 2, 4),
  92. (1, 3, 1), (1, 4, 2) }
  93. def test_batcher_03():
  94. d = Data()
  95. d.add_vertex_type('Gene', 5)
  96. d.add_vertex_type('Drug', 4)
  97. d.add_edge_type('Gene-Gene', 0, 0, [
  98. torch.tensor([
  99. [0, 1, 0, 1, 0],
  100. [0, 0, 0, 0, 1],
  101. [1, 0, 0, 0, 0],
  102. [0, 0, 1, 0, 0],
  103. [0, 0, 0, 1, 0]
  104. ]).to_sparse(),
  105. torch.tensor([
  106. [0, 0, 1, 0, 1],
  107. [0, 0, 0, 1, 0],
  108. [0, 0, 0, 0, 1],
  109. [0, 1, 0, 0, 0],
  110. [0, 0, 1, 0, 0]
  111. ]).to_sparse()
  112. ], dedicom_decoder)
  113. d.add_edge_type('Gene-Drug', 0, 1, [
  114. torch.tensor([
  115. [0, 1, 0, 0],
  116. [1, 0, 0, 1],
  117. [0, 1, 0, 0],
  118. [0, 0, 1, 0],
  119. [0, 1, 1, 0]
  120. ]).to_sparse()
  121. ], dedicom_decoder)
  122. b = Batcher(d, batch_size=1)
  123. visited = set()
  124. for t in b:
  125. print(t)
  126. k = (t.vertex_type_row, t.vertex_type_column,
  127. t.relation_type_index,) + \
  128. tuple(t.edges[0].tolist())
  129. visited.add(k)
  130. assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
  131. (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
  132. (0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
  133. (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
  134. (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
  135. (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
  136. (0, 1, 0, 4, 2) }
  137. def test_batcher_04():
  138. d = Data()
  139. d.add_vertex_type('Gene', 5)
  140. d.add_edge_type('Gene-Gene', 0, 0, [
  141. torch.tensor([
  142. [0, 1, 0, 1, 0],
  143. [0, 0, 0, 0, 1],
  144. [1, 0, 0, 0, 0],
  145. [0, 0, 1, 0, 0],
  146. [0, 0, 0, 1, 0]
  147. ]).to_sparse()
  148. ], dedicom_decoder)
  149. b = Batcher(d, batch_size=3)
  150. visited = set()
  151. for t in b:
  152. print(t)
  153. for e in t.edges:
  154. k = tuple(e.tolist())
  155. visited.add(k)
  156. assert visited == { (0, 1), (0, 3),
  157. (1, 4), (2, 0), (3, 2), (4, 3) }
  158. def test_batcher_05():
  159. d = Data()
  160. d.add_vertex_type('Gene', 5)
  161. d.add_vertex_type('Drug', 4)
  162. d.add_edge_type('Gene-Gene', 0, 0, [
  163. torch.tensor([
  164. [0, 1, 0, 1, 0],
  165. [0, 0, 0, 0, 1],
  166. [1, 0, 0, 0, 0],
  167. [0, 0, 1, 0, 0],
  168. [0, 0, 0, 1, 0]
  169. ]).to_sparse(),
  170. torch.tensor([
  171. [0, 0, 1, 0, 1],
  172. [0, 0, 0, 1, 0],
  173. [0, 0, 0, 0, 1],
  174. [0, 1, 0, 0, 0],
  175. [0, 0, 1, 0, 0]
  176. ]).to_sparse()
  177. ], dedicom_decoder)
  178. d.add_edge_type('Gene-Drug', 0, 1, [
  179. torch.tensor([
  180. [0, 1, 0, 0],
  181. [1, 0, 0, 1],
  182. [0, 1, 0, 0],
  183. [0, 0, 1, 0],
  184. [0, 1, 1, 0]
  185. ]).to_sparse()
  186. ], dedicom_decoder)
  187. b = Batcher(d, batch_size=5)
  188. visited = set()
  189. for t in b:
  190. print(t)
  191. for e in t.edges:
  192. k = (t.vertex_type_row, t.vertex_type_column,
  193. t.relation_type_index,) + \
  194. tuple(e.tolist())
  195. visited.add(k)
  196. assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
  197. (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
  198. (0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
  199. (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
  200. (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
  201. (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
  202. (0, 1, 0, 4, 2) }
  203. def test_dual_batcher_01():
  204. d = Data()
  205. d.add_vertex_type('Gene', 5)
  206. d.add_vertex_type('Drug', 4)
  207. d.add_edge_type('Gene-Gene', 0, 0, [
  208. torch.tensor([
  209. [0, 1, 0, 1, 0],
  210. [0, 0, 0, 0, 1],
  211. [1, 0, 0, 0, 0],
  212. [0, 0, 1, 0, 0],
  213. [0, 0, 0, 1, 0]
  214. ]).to_sparse(),
  215. torch.tensor([
  216. [0, 0, 1, 0, 1],
  217. [0, 0, 0, 1, 0],
  218. [0, 0, 0, 0, 1],
  219. [0, 1, 0, 0, 0],
  220. [0, 0, 1, 0, 0]
  221. ]).to_sparse()
  222. ], dedicom_decoder)
  223. d.add_edge_type('Gene-Drug', 0, 1, [
  224. torch.tensor([
  225. [0, 1, 0, 0],
  226. [1, 0, 0, 1],
  227. [0, 1, 0, 0],
  228. [0, 0, 1, 0],
  229. [0, 1, 1, 0]
  230. ]).to_sparse()
  231. ], dedicom_decoder)
  232. b = DualBatcher(d, d, batch_size=5)
  233. visited_pos = set()
  234. visited_neg = set()
  235. for t_pos, t_neg in b:
  236. assert t_pos.vertex_type_row == t_neg.vertex_type_row
  237. assert t_pos.vertex_type_column == t_neg.vertex_type_column
  238. assert t_pos.relation_type_index == t_neg.relation_type_index
  239. assert len(t_pos.edges) == len(t_neg.edges)
  240. for e in t_pos.edges:
  241. k = (t_pos.vertex_type_row, t_pos.vertex_type_column,
  242. t_pos.relation_type_index,) + \
  243. tuple(e.tolist())
  244. visited_pos.add(k)
  245. for e in t_neg.edges:
  246. k = (t_neg.vertex_type_row, t_neg.vertex_type_column,
  247. t_neg.relation_type_index,) + \
  248. tuple(e.tolist())
  249. visited_neg.add(k)
  250. expected = { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
  251. (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
  252. (0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
  253. (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
  254. (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
  255. (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
  256. (0, 1, 0, 4, 2) }
  257. assert visited_pos == expected
  258. assert visited_neg == expected