If you would like an account, please send an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

162 lines
4.6KB

  1. from icosagon.decode import DEDICOMDecoder, \
  2. DistMultDecoder, \
  3. BilinearDecoder, \
  4. InnerProductDecoder
  5. import decagon_pytorch.decode.pairwise
  6. import torch
  7. def test_dedicom_decoder_01():
  8. repr_ = torch.rand(20, 32)
  9. dec_1 = DEDICOMDecoder(32, 7, keep_prob=1.,
  10. activation=torch.sigmoid)
  11. dec_2 = decagon_pytorch.decode.pairwise.DEDICOMDecoder(32, 7, drop_prob=0.,
  12. activation=torch.sigmoid)
  13. dec_2.global_interaction = dec_1.global_interaction
  14. dec_2.local_variation = dec_1.local_variation
  15. res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
  16. res_2 = dec_2(repr_, repr_)
  17. assert isinstance(res_1, list)
  18. assert isinstance(res_2, list)
  19. assert len(res_1) == len(res_2)
  20. for i in range(len(res_1)):
  21. assert torch.all(res_1[i] == res_2[i])
  22. def test_dist_mult_decoder_01():
  23. repr_ = torch.rand(20, 32)
  24. dec_1 = DistMultDecoder(32, 7, keep_prob=1.,
  25. activation=torch.sigmoid)
  26. dec_2 = decagon_pytorch.decode.pairwise.DistMultDecoder(32, 7, drop_prob=0.,
  27. activation=torch.sigmoid)
  28. dec_2.relation = dec_1.relation
  29. res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
  30. res_2 = dec_2(repr_, repr_)
  31. assert isinstance(res_1, list)
  32. assert isinstance(res_2, list)
  33. assert len(res_1) == len(res_2)
  34. for i in range(len(res_1)):
  35. assert torch.all(res_1[i] == res_2[i])
  36. def test_bilinear_decoder_01():
  37. repr_ = torch.rand(20, 32)
  38. dec_1 = BilinearDecoder(32, 7, keep_prob=1.,
  39. activation=torch.sigmoid)
  40. dec_2 = decagon_pytorch.decode.pairwise.BilinearDecoder(32, 7, drop_prob=0.,
  41. activation=torch.sigmoid)
  42. dec_2.relation = dec_1.relation
  43. res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
  44. res_2 = dec_2(repr_, repr_)
  45. assert isinstance(res_1, list)
  46. assert isinstance(res_2, list)
  47. assert len(res_1) == len(res_2)
  48. for i in range(len(res_1)):
  49. assert torch.all(res_1[i] == res_2[i])
  50. def test_inner_product_decoder_01():
  51. repr_ = torch.rand(20, 32)
  52. dec_1 = InnerProductDecoder(32, 7, keep_prob=1.,
  53. activation=torch.sigmoid)
  54. dec_2 = decagon_pytorch.decode.pairwise.InnerProductDecoder(32, 7, drop_prob=0.,
  55. activation=torch.sigmoid)
  56. res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
  57. res_2 = dec_2(repr_, repr_)
  58. assert isinstance(res_1, list)
  59. assert isinstance(res_2, list)
  60. assert len(res_1) == len(res_2)
  61. for i in range(len(res_1)):
  62. assert torch.all(res_1[i] == res_2[i])
  63. def test_is_dedicom_not_symmetric_01():
  64. repr_1 = torch.rand(20, 32)
  65. repr_2 = torch.rand(20, 32)
  66. dec = DEDICOMDecoder(32, 7, keep_prob=1.,
  67. activation=torch.sigmoid)
  68. res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
  69. res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
  70. assert isinstance(res_1, list)
  71. assert isinstance(res_2, list)
  72. assert len(res_1) == len(res_2)
  73. for i in range(len(res_1)):
  74. assert not torch.all(res_1[i] - res_2[i] < 0.000001)
  75. def test_is_dist_mult_symmetric_01():
  76. repr_1 = torch.rand(20, 32)
  77. repr_2 = torch.rand(20, 32)
  78. dec = DistMultDecoder(32, 7, keep_prob=1.,
  79. activation=torch.sigmoid)
  80. res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
  81. res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
  82. assert isinstance(res_1, list)
  83. assert isinstance(res_2, list)
  84. assert len(res_1) == len(res_2)
  85. for i in range(len(res_1)):
  86. assert torch.all(res_1[i] - res_2[i] < 0.000001)
  87. def test_is_bilinear_not_symmetric_01():
  88. repr_1 = torch.rand(20, 32)
  89. repr_2 = torch.rand(20, 32)
  90. dec = BilinearDecoder(32, 7, keep_prob=1.,
  91. activation=torch.sigmoid)
  92. res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
  93. res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
  94. assert isinstance(res_1, list)
  95. assert isinstance(res_2, list)
  96. assert len(res_1) == len(res_2)
  97. for i in range(len(res_1)):
  98. assert not torch.all(res_1[i] - res_2[i] < 0.000001)
  99. def test_is_inner_product_symmetric_01():
  100. repr_1 = torch.rand(20, 32)
  101. repr_2 = torch.rand(20, 32)
  102. dec = InnerProductDecoder(32, 7, keep_prob=1.,
  103. activation=torch.sigmoid)
  104. res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
  105. res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
  106. assert isinstance(res_1, list)
  107. assert isinstance(res_2, list)
  108. assert len(res_1) == len(res_2)
  109. for i in range(len(res_1)):
  110. assert torch.all(res_1[i] - res_2[i] < 0.000001)