IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

Add symmetry tests for decode.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
4b0f68d1ad
共有 1 個檔案被更改,包括 75 行新增0 行删除
  1. +75
    -0
      tests/icosagon/test_decode.py

+ 75
- 0
tests/icosagon/test_decode.py 查看文件

@@ -84,3 +84,78 @@ def test_inner_product_decoder_01():
for i in range(len(res_1)):
assert torch.all(res_1[i] == res_2[i])
def test_is_dedicom_not_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = DEDICOMDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_dist_mult_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = DistMultDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_bilinear_not_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = BilinearDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_inner_product_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = InnerProductDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] - res_2[i] < 0.000001)

Loading…
取消
儲存