IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

250 lines
6.9KB

  1. from triacontagon.split import split_adj_mat, \
  2. split_edge_type, \
  3. split_data
  4. from triacontagon.util import _equal
  5. from triacontagon.data import EdgeType, \
  6. Data
  7. from triacontagon.decode import dedicom_decoder
  8. import torch
  9. def test_split_adj_mat_01():
  10. adj_mat = torch.tensor([
  11. [0, 1, 0, 0, 1],
  12. [0, 0, 1, 0, 1],
  13. [1, 0, 0, 1, 0],
  14. [0, 0, 1, 1, 0]
  15. ]).to_sparse()
  16. (res,) = split_adj_mat(adj_mat, (1.,))
  17. assert torch.all(_equal(res, adj_mat))
  18. def test_split_adj_mat_02():
  19. adj_mat = torch.tensor([
  20. [0, 1, 0, 0, 1],
  21. [0, 0, 1, 0, 1],
  22. [1, 0, 0, 1, 0],
  23. [0, 0, 1, 1, 0]
  24. ]).to_sparse()
  25. a, b = split_adj_mat(adj_mat, ( .5, .5 ))
  26. assert torch.all(_equal(a+b, adj_mat))
  27. def test_split_adj_mat_03():
  28. adj_mat = torch.tensor([
  29. [0, 1, 0, 0, 1],
  30. [0, 0, 1, 0, 1],
  31. [1, 0, 0, 1, 0],
  32. [0, 0, 1, 1, 0]
  33. ]).to_sparse()
  34. a, b, c = split_adj_mat(adj_mat, ( .8, .1, .1 ))
  35. print('a:', a.to_dense(), 'b:', b.to_dense(), 'c:', c.to_dense())
  36. assert torch.all(_equal(a+b+c, adj_mat))
  37. def test_split_edge_type_01():
  38. et = EdgeType('Dummy', 0, 1, [
  39. torch.tensor([
  40. [0, 1, 0, 0, 0],
  41. [0, 0, 1, 0, 1],
  42. [1, 0, 0, 0, 1],
  43. [0, 1, 0, 1, 0]
  44. ]).to_sparse()
  45. ], None, None)
  46. res = split_edge_type(et, (1.,))
  47. assert torch.all(_equal(et.adjacency_matrices[0],
  48. res[0].adjacency_matrices[0]))
  49. def test_split_edge_type_02():
  50. et = EdgeType('Dummy', 0, 1, [
  51. torch.tensor([
  52. [0, 1, 0, 0, 0],
  53. [0, 0, 1, 0, 1],
  54. [1, 0, 0, 0, 1],
  55. [0, 1, 0, 1, 0]
  56. ]).to_sparse()
  57. ], None, None)
  58. res = split_edge_type(et, (.5, .5))
  59. assert torch.all(_equal(et.adjacency_matrices[0],
  60. res[0].adjacency_matrices[0] + \
  61. res[1].adjacency_matrices[0]))
  62. def test_split_edge_type_03():
  63. et = EdgeType('Dummy', 0, 1, [
  64. torch.tensor([
  65. [0, 1, 0, 0, 0],
  66. [0, 0, 1, 0, 1],
  67. [1, 0, 0, 0, 1],
  68. [0, 1, 0, 1, 0]
  69. ]).to_sparse()
  70. ], None, None)
  71. res = split_edge_type(et, (.4, .4, .2))
  72. assert torch.all(_equal(et.adjacency_matrices[0],
  73. res[0].adjacency_matrices[0] + \
  74. res[1].adjacency_matrices[0] + \
  75. res[2].adjacency_matrices[0]))
  76. def test_split_edge_type_04():
  77. et = EdgeType('Dummy', 0, 1, [
  78. torch.tensor([
  79. [0, 1, 0, 0, 0],
  80. [0, 0, 1, 0, 1],
  81. [1, 0, 0, 0, 1],
  82. [0, 1, 0, 1, 0]
  83. ]).to_sparse(),
  84. torch.tensor([
  85. [1, 0, 0, 0, 0],
  86. [0, 1, 0, 1, 0],
  87. [0, 0, 1, 1, 0],
  88. [1, 0, 1, 0, 0]
  89. ]).to_sparse()
  90. ], None, None)
  91. res = split_edge_type(et, (.4, .4, .2))
  92. assert torch.all(_equal(et.adjacency_matrices[0],
  93. res[0].adjacency_matrices[0] + \
  94. res[1].adjacency_matrices[0] + \
  95. res[2].adjacency_matrices[0]))
  96. assert torch.all(_equal(et.adjacency_matrices[1],
  97. res[0].adjacency_matrices[1] + \
  98. res[1].adjacency_matrices[1] + \
  99. res[2].adjacency_matrices[1]))
  100. def test_split_data_01():
  101. data = Data()
  102. data.add_vertex_type('Foo', 5)
  103. data.add_vertex_type('Bar', 4)
  104. foo_foo = torch.tensor([
  105. [0, 1, 0, 1, 0],
  106. [0, 0, 0, 1, 0],
  107. [0, 1, 0, 0, 1],
  108. [0, 1, 0, 0, 0],
  109. [1, 0, 0, 1, 0]
  110. ], dtype=torch.float32)
  111. foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
  112. foo_bar = torch.tensor([
  113. [0, 1, 0, 1],
  114. [0, 0, 0, 1],
  115. [0, 1, 0, 0],
  116. [1, 0, 0, 0],
  117. [0, 0, 1, 1]
  118. ], dtype=torch.float32)
  119. bar_foo = foo_bar.transpose(0, 1)
  120. bar_bar = torch.tensor([
  121. [0, 0, 1, 0],
  122. [1, 0, 0, 0],
  123. [0, 1, 0, 1],
  124. [0, 1, 0, 0],
  125. ], dtype=torch.float32)
  126. bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
  127. data.add_edge_type('Foo-Foo', 0, 0, [
  128. foo_foo.to_sparse().coalesce()
  129. ], dedicom_decoder)
  130. data.add_edge_type('Foo-Bar', 0, 1, [
  131. foo_bar.to_sparse().coalesce()
  132. ], dedicom_decoder)
  133. data.add_edge_type('Bar-Foo', 1, 0, [
  134. bar_foo.to_sparse().coalesce()
  135. ], dedicom_decoder)
  136. data.add_edge_type('Bar-Bar', 1, 1, [
  137. bar_bar.to_sparse().coalesce()
  138. ], dedicom_decoder)
  139. (res,) = split_data(data, (1.,))
  140. assert torch.all(_equal(res.edge_types[0, 0].adjacency_matrices[0],
  141. data.edge_types[0, 0].adjacency_matrices[0]))
  142. assert torch.all(_equal(res.edge_types[0, 1].adjacency_matrices[0],
  143. data.edge_types[0, 1].adjacency_matrices[0]))
  144. assert torch.all(_equal(res.edge_types[1, 0].adjacency_matrices[0],
  145. data.edge_types[1, 0].adjacency_matrices[0]))
  146. assert torch.all(_equal(res.edge_types[1, 1].adjacency_matrices[0],
  147. data.edge_types[1, 1].adjacency_matrices[0]))
  148. def test_split_data_02():
  149. data = Data()
  150. data.add_vertex_type('Foo', 5)
  151. data.add_vertex_type('Bar', 4)
  152. foo_foo = torch.tensor([
  153. [0, 1, 0, 1, 0],
  154. [0, 0, 0, 1, 0],
  155. [0, 1, 0, 0, 1],
  156. [0, 1, 0, 0, 0],
  157. [1, 0, 0, 1, 0]
  158. ], dtype=torch.float32)
  159. foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
  160. foo_bar = torch.tensor([
  161. [0, 1, 0, 1],
  162. [0, 0, 0, 1],
  163. [0, 1, 0, 0],
  164. [1, 0, 0, 0],
  165. [0, 0, 1, 1]
  166. ], dtype=torch.float32)
  167. bar_foo = foo_bar.transpose(0, 1)
  168. bar_bar = torch.tensor([
  169. [0, 0, 1, 0],
  170. [1, 0, 0, 0],
  171. [0, 1, 0, 1],
  172. [0, 1, 0, 0],
  173. ], dtype=torch.float32)
  174. bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
  175. data.add_edge_type('Foo-Foo', 0, 0, [
  176. foo_foo.to_sparse().coalesce()
  177. ], dedicom_decoder)
  178. data.add_edge_type('Foo-Bar', 0, 1, [
  179. foo_bar.to_sparse().coalesce()
  180. ], dedicom_decoder)
  181. data.add_edge_type('Bar-Foo', 1, 0, [
  182. bar_foo.to_sparse().coalesce()
  183. ], dedicom_decoder)
  184. data.add_edge_type('Bar-Bar', 1, 1, [
  185. bar_bar.to_sparse().coalesce()
  186. ], dedicom_decoder)
  187. a, b = split_data(data, (.5,.5))
  188. assert torch.all(_equal(a.edge_types[0, 0].adjacency_matrices[0] + \
  189. b.edge_types[0, 0].adjacency_matrices[0],
  190. data.edge_types[0, 0].adjacency_matrices[0]))
  191. assert torch.all(_equal(a.edge_types[0, 1].adjacency_matrices[0] + \
  192. b.edge_types[0, 1].adjacency_matrices[0],
  193. data.edge_types[0, 1].adjacency_matrices[0]))
  194. assert torch.all(_equal(a.edge_types[1, 0].adjacency_matrices[0] + \
  195. b.edge_types[1, 0].adjacency_matrices[0],
  196. data.edge_types[1, 0].adjacency_matrices[0]))
  197. assert torch.all(_equal(a.edge_types[1, 1].adjacency_matrices[0] + \
  198. b.edge_types[1, 1].adjacency_matrices[0],
  199. data.edge_types[1, 1].adjacency_matrices[0]))