From 77c52d8543ed71597ca9f034a0e5399bf78f120f Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Sat, 20 Jun 2020 15:06:08 +0200 Subject: [PATCH] Add test_[dedicom,dist_mult,bilinear,inner_product]_decoder_parameter_count_01(). --- tests/icosagon/test_decode.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/icosagon/test_decode.py b/tests/icosagon/test_decode.py index 96ca3d7..b8c9cea 100644 --- a/tests/icosagon/test_decode.py +++ b/tests/icosagon/test_decode.py @@ -210,3 +210,31 @@ def test_empty_inner_product_decoder_01(): for i in range(len(res)): assert res[i].shape == (0,) + + +def test_dedicom_decoder_parameter_count_01(): + dec = DEDICOMDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + assert len(list(dec.parameters())) == 8 + + +def test_dist_mult_decoder_parameter_count_01(): + dec = DistMultDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + assert len(list(dec.parameters())) == 7 + + +def test_bilinear_decoder_parameter_count_01(): + dec = BilinearDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + assert len(list(dec.parameters())) == 7 + + +def test_inner_product_decoder_parameter_count_01(): + dec = InnerProductDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + assert len(list(dec.parameters())) == 0