IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
소스 검색

Add symmetry tests for decode.

master
Stanislaw Adaszewski 3 년 전
부모
커밋
4b0f68d1ad
1개의 변경된 파일75개의 추가작업 그리고 0개의 파일을 삭제
  1. +75
    -0
      tests/icosagon/test_decode.py

+ 75
- 0
tests/icosagon/test_decode.py 파일 보기

@@ -84,3 +84,78 @@ def test_inner_product_decoder_01():
for i in range(len(res_1)):
assert torch.all(res_1[i] == res_2[i])
def test_is_dedicom_not_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = DEDICOMDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_dist_mult_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = DistMultDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_bilinear_not_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = BilinearDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
def test_is_inner_product_symmetric_01():
repr_1 = torch.rand(20, 32)
repr_2 = torch.rand(20, 32)
dec = InnerProductDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] - res_2[i] < 0.000001)

불러오는 중...
취소
저장