IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

125 lines
3.2KB

  1. from triacontagon.split import split_adj_mat, \
  2. split_edge_type
  3. from triacontagon.util import _equal
  4. from triacontagon.data import EdgeType
  5. import torch
  6. def test_split_adj_mat_01():
  7. adj_mat = torch.tensor([
  8. [0, 1, 0, 0, 1],
  9. [0, 0, 1, 0, 1],
  10. [1, 0, 0, 1, 0],
  11. [0, 0, 1, 1, 0]
  12. ]).to_sparse()
  13. (res,) = split_adj_mat(adj_mat, (1.,))
  14. assert torch.all(_equal(res, adj_mat))
  15. def test_split_adj_mat_02():
  16. adj_mat = torch.tensor([
  17. [0, 1, 0, 0, 1],
  18. [0, 0, 1, 0, 1],
  19. [1, 0, 0, 1, 0],
  20. [0, 0, 1, 1, 0]
  21. ]).to_sparse()
  22. a, b = split_adj_mat(adj_mat, ( .5, .5 ))
  23. assert torch.all(_equal(a+b, adj_mat))
  24. def test_split_adj_mat_03():
  25. adj_mat = torch.tensor([
  26. [0, 1, 0, 0, 1],
  27. [0, 0, 1, 0, 1],
  28. [1, 0, 0, 1, 0],
  29. [0, 0, 1, 1, 0]
  30. ]).to_sparse()
  31. a, b, c = split_adj_mat(adj_mat, ( .8, .1, .1 ))
  32. print('a:', a.to_dense(), 'b:', b.to_dense(), 'c:', c.to_dense())
  33. assert torch.all(_equal(a+b+c, adj_mat))
  34. def test_split_edge_type_01():
  35. et = EdgeType('Dummy', 0, 1, [
  36. torch.tensor([
  37. [0, 1, 0, 0, 0],
  38. [0, 0, 1, 0, 1],
  39. [1, 0, 0, 0, 1],
  40. [0, 1, 0, 1, 0]
  41. ]).to_sparse()
  42. ], None, None)
  43. res = split_edge_type(et, (1.,))
  44. assert torch.all(_equal(et.adjacency_matrices[0],
  45. res[0].adjacency_matrices[0]))
  46. def test_split_edge_type_02():
  47. et = EdgeType('Dummy', 0, 1, [
  48. torch.tensor([
  49. [0, 1, 0, 0, 0],
  50. [0, 0, 1, 0, 1],
  51. [1, 0, 0, 0, 1],
  52. [0, 1, 0, 1, 0]
  53. ]).to_sparse()
  54. ], None, None)
  55. res = split_edge_type(et, (.5, .5))
  56. assert torch.all(_equal(et.adjacency_matrices[0],
  57. res[0].adjacency_matrices[0] + \
  58. res[1].adjacency_matrices[0]))
  59. def test_split_edge_type_03():
  60. et = EdgeType('Dummy', 0, 1, [
  61. torch.tensor([
  62. [0, 1, 0, 0, 0],
  63. [0, 0, 1, 0, 1],
  64. [1, 0, 0, 0, 1],
  65. [0, 1, 0, 1, 0]
  66. ]).to_sparse()
  67. ], None, None)
  68. res = split_edge_type(et, (.4, .4, .2))
  69. assert torch.all(_equal(et.adjacency_matrices[0],
  70. res[0].adjacency_matrices[0] + \
  71. res[1].adjacency_matrices[0] + \
  72. res[2].adjacency_matrices[0]))
  73. def test_split_edge_type_04():
  74. et = EdgeType('Dummy', 0, 1, [
  75. torch.tensor([
  76. [0, 1, 0, 0, 0],
  77. [0, 0, 1, 0, 1],
  78. [1, 0, 0, 0, 1],
  79. [0, 1, 0, 1, 0]
  80. ]).to_sparse(),
  81. torch.tensor([
  82. [1, 0, 0, 0, 0],
  83. [0, 1, 0, 1, 0],
  84. [0, 0, 1, 1, 0],
  85. [1, 0, 1, 0, 0]
  86. ]).to_sparse()
  87. ], None, None)
  88. res = split_edge_type(et, (.4, .4, .2))
  89. assert torch.all(_equal(et.adjacency_matrices[0],
  90. res[0].adjacency_matrices[0] + \
  91. res[1].adjacency_matrices[0] + \
  92. res[2].adjacency_matrices[0]))
  93. assert torch.all(_equal(et.adjacency_matrices[1],
  94. res[0].adjacency_matrices[1] + \
  95. res[1].adjacency_matrices[1] + \
  96. res[2].adjacency_matrices[1]))