| @@ -86,3 +86,19 @@ def test_model_03(): | |||||
| assert len(list(m.seq[2].parameters())) == 1 | assert len(list(m.seq[2].parameters())) == 1 | ||||
| assert len(list(m.seq[3].parameters())) == 2 | assert len(list(m.seq[3].parameters())) == 2 | ||||
| # print(list(m.seq[1].parameters())) | # print(list(m.seq[1].parameters())) | ||||
| def test_model_04(): | |||||
| d = Data() | |||||
| d.add_node_type('Dummy', 10) | |||||
| fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||||
| mat = torch.rand(10, 10).round().to_sparse() | |||||
| fam.add_relation_type('Dummy Rel 1', mat) | |||||
| fam.add_relation_type('Dummy Rel 2', mat.clone()) | |||||
| m = Model(d, ratios=TrainValTest(1., 0., 0.)) | |||||
| assert len(list(m.seq[0].parameters())) == 1 | |||||
| assert len(list(m.seq[1].parameters())) == 2 | |||||
| assert len(list(m.seq[2].parameters())) == 2 | |||||
| assert len(list(m.seq[3].parameters())) == 3 | |||||