| @@ -180,3 +180,25 @@ def test_model_07(): | |||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| m = Model(prep_d, dec_activation='x') | m = Model(prep_d, dec_activation='x') | ||||
| def test_model_08(): | |||||
| d = Data() | |||||
| d.add_node_type('Dummy', 10) | |||||
| d.add_node_type('Foobar', 20) | |||||
| fam = d.add_relation_family('Dummy-Foobar', 0, 1, True) | |||||
| mat = torch.rand(10, 20).round().to_sparse() | |||||
| fam.add_relation_type('Dummy Rel 1', mat) | |||||
| fam.add_relation_type('Dummy Rel 2', mat.clone()) | |||||
| fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False) | |||||
| mat = torch.rand(10, 10).round().to_sparse() | |||||
| fam.add_relation_type('Dummy Rel 2-1', mat) | |||||
| fam.add_relation_type('Dummy Rel 2-2', mat.clone()) | |||||
| prep_d = prepare_training(d, TrainValTest(1., 0., 0.)) | |||||
| m = Model(prep_d) | |||||
| assert len(list(m.parameters())) == 20 | |||||