diff --git a/tests/icosagon/test_convolve.py b/tests/icosagon/test_convolve.py index 79fb2ac..d4df6ea 100644 --- a/tests/icosagon/test_convolve.py +++ b/tests/icosagon/test_convolve.py @@ -196,3 +196,11 @@ def test_graph_conv_parameter_count_01(): conv = GraphConv(20, 20, adj_mat) assert len(list(conv.parameters())) == 1 + + +def test_dropout_graph_conv_activation_parameter_count_01(): + adj_mat = torch.rand((10, 20)).round() + + conv = DropoutGraphConvActivation(20, 20, adj_mat) + + assert len(list(conv.parameters())) == 1