diff --git a/tests/icosagon/test_declayer.py b/tests/icosagon/test_declayer.py index 216b1ed..732fd45 100644 --- a/tests/icosagon/test_declayer.py +++ b/tests/icosagon/test_declayer.py @@ -232,3 +232,56 @@ def test_decode_layer_05(): # assert isinstance(rel_pred.edges_neg, TrainValTest) # assert isinstance(rel_pred.edges_back_pos, TrainValTest) # assert isinstance(rel_pred.edges_back_neg, TrainValTest) + + +def test_decode_layer_parameter_count_01(): + d = Data() + d.add_node_type('Dummy', 100) + + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Relation 1', + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) + + dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1., + activation=lambda x: x) + + assert len(list(dec.parameters())) == 2 + + +def test_decode_layer_parameter_count_02(): + d = Data() + d.add_node_type('Dummy', 100) + + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Relation 1', + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + fam.add_relation_type('Dummy Relation 2', + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) + + dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1., + activation=lambda x: x) + + assert len(list(dec.parameters())) == 3 + + +def test_decode_layer_parameter_count_03(): + d = Data() + d.add_node_type('Dummy', 100) + + for _ in range(2): + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Relation 1', + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + fam.add_relation_type('Dummy Relation 2', + torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) + + dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1., + activation=lambda x: x) + + assert len(list(dec.parameters())) == 6