diff --git a/tests/icosagon/test_decode.py b/tests/icosagon/test_decode.py index 96ca3d7..b8c9cea 100644 --- a/tests/icosagon/test_decode.py +++ b/tests/icosagon/test_decode.py @@ -210,3 +210,31 @@ def test_empty_inner_product_decoder_01(): for i in range(len(res)): assert res[i].shape == (0,) + + +def test_dedicom_decoder_parameter_count_01(): + dec = DEDICOMDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + assert len(list(dec.parameters())) == 8 + + +def test_dist_mult_decoder_parameter_count_01(): + dec = DistMultDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + assert len(list(dec.parameters())) == 7 + + +def test_bilinear_decoder_parameter_count_01(): + dec = BilinearDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + assert len(list(dec.parameters())) == 7 + + +def test_inner_product_decoder_parameter_count_01(): + dec = InnerProductDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + + assert len(list(dec.parameters())) == 0