IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add test_model_05().

master
Stanislaw Adaszewski 3 years ago
parent
commit
f6e1024428
1 changed files with 45 additions and 0 deletions
  1. +45
    -0
      tests/icosagon/test_model.py

+ 45
- 0
tests/icosagon/test_model.py View File

@@ -102,3 +102,48 @@ def test_model_04():
assert len(list(m.seq[1].parameters())) == 2
assert len(list(m.seq[2].parameters())) == 2
assert len(list(m.seq[3].parameters())) == 3
def test_model_05():
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
mat = torch.rand(10, 10).round().to_sparse()
fam.add_relation_type('Dummy Rel 1', mat)
fam.add_relation_type('Dummy Rel 2', mat.clone())
fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
mat = torch.rand(10, 10).round().to_sparse()
fam.add_relation_type('Dummy Rel 2-1', mat)
fam.add_relation_type('Dummy Rel 2-2', mat.clone())
m = Model(d, ratios=TrainValTest(1., 0., 0.))
assert len(list(m.seq[0].parameters())) == 1
assert len(list(m.seq[1].parameters())) == 4
assert len(list(m.seq[2].parameters())) == 4
assert len(list(m.seq[3].parameters())) == 6
def test_model_05():
d = Data()
d.add_node_type('Dummy', 10)
d.add_node_type('Foobar', 20)
fam = d.add_relation_family('Dummy-Foobar', 0, 1, True)
mat = torch.rand(10, 20).round().to_sparse()
fam.add_relation_type('Dummy Rel 1', mat)
fam.add_relation_type('Dummy Rel 2', mat.clone())
fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
mat = torch.rand(10, 10).round().to_sparse()
fam.add_relation_type('Dummy Rel 2-1', mat)
fam.add_relation_type('Dummy Rel 2-2', mat.clone())
m = Model(d, ratios=TrainValTest(1., 0., 0.))
assert len(list(m.seq[0].parameters())) == 2
assert len(list(m.seq[1].parameters())) == 6
assert len(list(m.seq[2].parameters())) == 6
assert len(list(m.seq[3].parameters())) == 6

Loading…
Cancel
Save