IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Переглянути джерело

Add symmetry tests for decode.

master
Stanislaw Adaszewski 4 роки тому
джерело
коміт
4b0f68d1ad
1 змінених файлів з 75 додано та 0 видалено
  1. +75
    -0
      tests/icosagon/test_decode.py

+ 75
- 0
tests/icosagon/test_decode.py Переглянути файл

@@ -84,3 +84,78 @@ def test_inner_product_decoder_01():
for i in range(len(res_1)):
assert torch.all(res_1[i] == res_2[i])
def test_is_dedicom_not_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = DEDICOMDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_dist_mult_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = DistMultDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_bilinear_not_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = BilinearDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_inner_product_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = InnerProductDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] - res_2[i] < 0.000001)

Завантаження…
Відмінити
Зберегти