From 4b0f68d1ad2b739d9c2634f20d99a8a360c2eddd Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 9 Jun 2020 13:34:58 +0200 Subject: [PATCH] Add symmetry tests for decode. --- tests/icosagon/test_decode.py | 75 +++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/tests/icosagon/test_decode.py b/tests/icosagon/test_decode.py index c0ec4c3..ee0eaf2 100644 --- a/tests/icosagon/test_decode.py +++ b/tests/icosagon/test_decode.py @@ -84,3 +84,78 @@ def test_inner_product_decoder_01(): for i in range(len(res_1)): assert torch.all(res_1[i] == res_2[i]) + + +def test_is_dedicom_not_symmetric_01(): + repr_1 = torch.rand(20, 32) + repr_2 = torch.rand(20, 32) + dec = DEDICOMDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] + res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] + + + assert isinstance(res_1, list) + assert isinstance(res_2, list) + + assert len(res_1) == len(res_2) + + for i in range(len(res_1)): + assert not torch.all(res_1[i] - res_2[i] < 0.000001) + + +def test_is_dist_mult_symmetric_01(): + repr_1 = torch.rand(20, 32) + repr_2 = torch.rand(20, 32) + dec = DistMultDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] + res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] + + + assert isinstance(res_1, list) + assert isinstance(res_2, list) + + assert len(res_1) == len(res_2) + + for i in range(len(res_1)): + assert torch.all(res_1[i] - res_2[i] < 0.000001) + + +def test_is_bilinear_not_symmetric_01(): + repr_1 = torch.rand(20, 32) + repr_2 = torch.rand(20, 32) + dec = BilinearDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] + res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] + + assert isinstance(res_1, list) + assert isinstance(res_2, list) + + assert len(res_1) == len(res_2) + + for i in range(len(res_1)): + assert not torch.all(res_1[i] - res_2[i] < 0.000001) + + +def test_is_inner_product_symmetric_01(): + repr_1 = torch.rand(20, 32) + repr_2 = torch.rand(20, 32) + dec = InnerProductDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] + res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] + + + assert isinstance(res_1, list) + assert isinstance(res_2, list) + + assert len(res_1) == len(res_2) + + for i in range(len(res_1)): + assert torch.all(res_1[i] - res_2[i] < 0.000001)