diff --git a/tests/icosagon/test_model.py b/tests/icosagon/test_model.py index 5c122d9..5e5d482 100644 --- a/tests/icosagon/test_model.py +++ b/tests/icosagon/test_model.py @@ -180,3 +180,25 @@ def test_model_07(): with pytest.raises(TypeError): m = Model(prep_d, dec_activation='x') + + +def test_model_08(): + d = Data() + d.add_node_type('Dummy', 10) + d.add_node_type('Foobar', 20) + + fam = d.add_relation_family('Dummy-Foobar', 0, 1, True) + mat = torch.rand(10, 20).round().to_sparse() + fam.add_relation_type('Dummy Rel 1', mat) + fam.add_relation_type('Dummy Rel 2', mat.clone()) + + fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False) + mat = torch.rand(10, 10).round().to_sparse() + fam.add_relation_type('Dummy Rel 2-1', mat) + fam.add_relation_type('Dummy Rel 2-2', mat.clone()) + + prep_d = prepare_training(d, TrainValTest(1., 0., 0.)) + + m = Model(prep_d) + + assert len(list(m.parameters())) == 20