diff --git a/tests/icosagon/test_input.py b/tests/icosagon/test_input.py index 8bb5dc0..1ea676a 100644 --- a/tests/icosagon/test_input.py +++ b/tests/icosagon/test_input.py @@ -107,3 +107,16 @@ def test_one_hot_input_layer_04(): layer = OneHotInputLayer(d) s = repr(layer) assert s.startswith('Icosagon one-hot input layer') + + +def test_one_hot_input_layer_parameter_count_01(): + d = _some_data() + layer = OneHotInputLayer(d) + assert len(list(layer.parameters())) == 2 + + +def test_input_layer_parameter_count_01(): + d = _some_data() + for output_dim in [32, 64, 128]: + layer = InputLayer(d, output_dim) + assert len(list(layer.parameters())) == 2