|
@@ -102,3 +102,48 @@ def test_model_04(): |
|
|
assert len(list(m.seq[1].parameters())) == 2
|
|
|
assert len(list(m.seq[1].parameters())) == 2
|
|
|
assert len(list(m.seq[2].parameters())) == 2
|
|
|
assert len(list(m.seq[2].parameters())) == 2
|
|
|
assert len(list(m.seq[3].parameters())) == 3
|
|
|
assert len(list(m.seq[3].parameters())) == 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_05():
|
|
|
|
|
|
d = Data()
|
|
|
|
|
|
d.add_node_type('Dummy', 10)
|
|
|
|
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
|
|
|
|
|
|
mat = torch.rand(10, 10).round().to_sparse()
|
|
|
|
|
|
fam.add_relation_type('Dummy Rel 1', mat)
|
|
|
|
|
|
fam.add_relation_type('Dummy Rel 2', mat.clone())
|
|
|
|
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
|
|
|
|
|
|
mat = torch.rand(10, 10).round().to_sparse()
|
|
|
|
|
|
fam.add_relation_type('Dummy Rel 2-1', mat)
|
|
|
|
|
|
fam.add_relation_type('Dummy Rel 2-2', mat.clone())
|
|
|
|
|
|
|
|
|
|
|
|
m = Model(d, ratios=TrainValTest(1., 0., 0.))
|
|
|
|
|
|
|
|
|
|
|
|
assert len(list(m.seq[0].parameters())) == 1
|
|
|
|
|
|
assert len(list(m.seq[1].parameters())) == 4
|
|
|
|
|
|
assert len(list(m.seq[2].parameters())) == 4
|
|
|
|
|
|
assert len(list(m.seq[3].parameters())) == 6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_05():
|
|
|
|
|
|
d = Data()
|
|
|
|
|
|
d.add_node_type('Dummy', 10)
|
|
|
|
|
|
d.add_node_type('Foobar', 20)
|
|
|
|
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Dummy-Foobar', 0, 1, True)
|
|
|
|
|
|
mat = torch.rand(10, 20).round().to_sparse()
|
|
|
|
|
|
fam.add_relation_type('Dummy Rel 1', mat)
|
|
|
|
|
|
fam.add_relation_type('Dummy Rel 2', mat.clone())
|
|
|
|
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
|
|
|
|
|
|
mat = torch.rand(10, 10).round().to_sparse()
|
|
|
|
|
|
fam.add_relation_type('Dummy Rel 2-1', mat)
|
|
|
|
|
|
fam.add_relation_type('Dummy Rel 2-2', mat.clone())
|
|
|
|
|
|
|
|
|
|
|
|
m = Model(d, ratios=TrainValTest(1., 0., 0.))
|
|
|
|
|
|
|
|
|
|
|
|
assert len(list(m.seq[0].parameters())) == 2
|
|
|
|
|
|
assert len(list(m.seq[1].parameters())) == 6
|
|
|
|
|
|
assert len(list(m.seq[2].parameters())) == 6
|
|
|
|
|
|
assert len(list(m.seq[3].parameters())) == 6
|