IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add symmetry tests for decode.

master
Stanislaw Adaszewski 3 years ago
parent
commit
4b0f68d1ad
1 changed files with 75 additions and 0 deletions
  1. +75
    -0
      tests/icosagon/test_decode.py

+ 75
- 0
tests/icosagon/test_decode.py View File

@@ -84,3 +84,78 @@ def test_inner_product_decoder_01():
for i in range(len(res_1)):
assert torch.all(res_1[i] == res_2[i])
def test_is_dedicom_not_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = DEDICOMDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_dist_mult_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = DistMultDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_bilinear_not_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = BilinearDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_inner_product_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = InnerProductDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] - res_2[i] < 0.000001)

Loading…
Cancel
Save