From f6e1024428066aa276cd6b9ba11e563a2f0f39ca Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Wed, 17 Jun 2020 15:04:19 +0200 Subject: [PATCH] Add test_model_05(). --- tests/icosagon/test_model.py | 45 ++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/icosagon/test_model.py b/tests/icosagon/test_model.py index c1ad0b6..2d39163 100644 --- a/tests/icosagon/test_model.py +++ b/tests/icosagon/test_model.py @@ -102,3 +102,48 @@ def test_model_04(): assert len(list(m.seq[1].parameters())) == 2 assert len(list(m.seq[2].parameters())) == 2 assert len(list(m.seq[3].parameters())) == 3 + + +def test_model_05(): + d = Data() + d.add_node_type('Dummy', 10) + + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + mat = torch.rand(10, 10).round().to_sparse() + fam.add_relation_type('Dummy Rel 1', mat) + fam.add_relation_type('Dummy Rel 2', mat.clone()) + + fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False) + mat = torch.rand(10, 10).round().to_sparse() + fam.add_relation_type('Dummy Rel 2-1', mat) + fam.add_relation_type('Dummy Rel 2-2', mat.clone()) + + m = Model(d, ratios=TrainValTest(1., 0., 0.)) + + assert len(list(m.seq[0].parameters())) == 1 + assert len(list(m.seq[1].parameters())) == 4 + assert len(list(m.seq[2].parameters())) == 4 + assert len(list(m.seq[3].parameters())) == 6 + + +def test_model_05(): + d = Data() + d.add_node_type('Dummy', 10) + d.add_node_type('Foobar', 20) + + fam = d.add_relation_family('Dummy-Foobar', 0, 1, True) + mat = torch.rand(10, 20).round().to_sparse() + fam.add_relation_type('Dummy Rel 1', mat) + fam.add_relation_type('Dummy Rel 2', mat.clone()) + + fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False) + mat = torch.rand(10, 10).round().to_sparse() + fam.add_relation_type('Dummy Rel 2-1', mat) + fam.add_relation_type('Dummy Rel 2-2', mat.clone()) + + m = Model(d, ratios=TrainValTest(1., 0., 0.)) + + assert len(list(m.seq[0].parameters())) == 2 + assert len(list(m.seq[1].parameters())) == 6 + assert len(list(m.seq[2].parameters())) == 6 + assert len(list(m.seq[3].parameters())) == 6