|
@@ -196,3 +196,11 @@ def test_graph_conv_parameter_count_01(): |
|
|
conv = GraphConv(20, 20, adj_mat)
|
|
|
conv = GraphConv(20, 20, adj_mat)
|
|
|
|
|
|
|
|
|
assert len(list(conv.parameters())) == 1
|
|
|
assert len(list(conv.parameters())) == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_dropout_graph_conv_activation_parameter_count_01():
|
|
|
|
|
|
adj_mat = torch.rand((10, 20)).round()
|
|
|
|
|
|
|
|
|
|
|
|
conv = DropoutGraphConvActivation(20, 20, adj_mat)
|
|
|
|
|
|
|
|
|
|
|
|
assert len(list(conv.parameters())) == 1
|