IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or generating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

test_normalize.py 3.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from icosagon.normalize import add_eye_sparse, \
  2. norm_adj_mat_one_node_type_sparse, \
  3. norm_adj_mat_one_node_type_dense, \
  4. norm_adj_mat_one_node_type
  5. import decagon_pytorch.normalize
  6. import torch
  7. import pytest
  8. import numpy as np
  9. from math import sqrt
  10. def test_add_eye_sparse_01():
  11. adj_mat_dense = torch.rand((10, 10))
  12. adj_mat_sparse = adj_mat_dense.to_sparse()
  13. adj_mat_dense += torch.eye(10)
  14. adj_mat_sparse = add_eye_sparse(adj_mat_sparse)
  15. assert torch.all(adj_mat_sparse.to_dense() == adj_mat_dense)
  16. def test_add_eye_sparse_02():
  17. adj_mat_dense = torch.rand((10, 20))
  18. adj_mat_sparse = adj_mat_dense.to_sparse()
  19. with pytest.raises(ValueError):
  20. _ = add_eye_sparse(adj_mat_sparse)
  21. def test_add_eye_sparse_03():
  22. adj_mat_dense = torch.rand((10, 10))
  23. with pytest.raises(ValueError):
  24. _ = add_eye_sparse(adj_mat_dense)
  25. def test_add_eye_sparse_04():
  26. adj_mat_dense = np.random.rand(10, 10)
  27. with pytest.raises(ValueError):
  28. _ = add_eye_sparse(adj_mat_dense)
  29. def test_norm_adj_mat_one_node_type_sparse_01():
  30. adj_mat = torch.rand((10, 10))
  31. adj_mat = (adj_mat > .5)
  32. adj_mat = adj_mat.to_sparse()
  33. _ = norm_adj_mat_one_node_type_sparse(adj_mat)
  34. def test_norm_adj_mat_one_node_type_sparse_02():
  35. adj_mat_dense = torch.rand((10, 10))
  36. adj_mat_dense = (adj_mat_dense > .5)
  37. adj_mat_sparse = adj_mat_dense.to_sparse()
  38. adj_mat_sparse = norm_adj_mat_one_node_type_sparse(adj_mat_sparse)
  39. adj_mat_dense = norm_adj_mat_one_node_type_dense(adj_mat_dense)
  40. assert torch.all(adj_mat_sparse.to_dense() - adj_mat_dense < 0.000001)
  41. def test_norm_adj_mat_one_node_type_dense_01():
  42. adj_mat = torch.rand((10, 10))
  43. adj_mat = (adj_mat > .5)
  44. _ = norm_adj_mat_one_node_type_dense(adj_mat)
  45. def test_norm_adj_mat_one_node_type_dense_02():
  46. adj_mat = torch.tensor([
  47. [0, 1, 1, 0], # 3
  48. [1, 0, 1, 0], # 3
  49. [1, 1, 0, 1], # 4
  50. [0, 0, 1, 0] # 2
  51. # 3 3 4 2
  52. ])
  53. expect_denom = np.array([
  54. [ 3, 3, sqrt(3)*2, sqrt(6) ],
  55. [ 3, 3, sqrt(3)*2, sqrt(6) ],
  56. [ sqrt(3)*2, sqrt(3)*2, 4, sqrt(2)*2 ],
  57. [ sqrt(6), sqrt(6), sqrt(2)*2, 2 ]
  58. ], dtype=np.float32)
  59. expect = (adj_mat.detach().cpu().numpy().astype(np.float32) + np.eye(4)) / expect_denom
  60. # expect = np.array([
  61. # [1/3, 1/3, 1/3, 0],
  62. # [1/3, 1/3, 1/3, 0],
  63. # [1/4, 1/4, 1/4, 1/4],
  64. # [0, 0, 1/2, 1/2]
  65. # ], dtype=np.float32)
  66. res = decagon_pytorch.normalize.norm_adj_mat_one_node_type(adj_mat)
  67. res = res.todense().astype(np.float32)
  68. print('res:', res)
  69. print('expect:', expect)
  70. assert np.all(res - expect < 0.000001)
  71. def test_norm_adj_mat_one_node_type_dense_03():
  72. # adj_mat = torch.rand((10, 10))
  73. adj_mat = torch.tensor([
  74. [0, 1, 1, 0, 0],
  75. [1, 0, 1, 0, 1],
  76. [1, 1, 0, .5, .5],
  77. [0, 0, .5, 0, 1],
  78. [0, 1, .5, 1, 0]
  79. ])
  80. # adj_mat = (adj_mat > .5)
  81. adj_mat_dec = decagon_pytorch.normalize.norm_adj_mat_one_node_type(adj_mat)
  82. adj_mat_ico = norm_adj_mat_one_node_type_dense(adj_mat)
  83. adj_mat_dec = adj_mat_dec.todense()
  84. adj_mat_ico = adj_mat_ico.detach().cpu().numpy()
  85. print('adj_mat_dec:', adj_mat_dec)
  86. print('adj_mat_ico:', adj_mat_ico)
  87. assert np.all(adj_mat_dec - adj_mat_ico < 0.000001)