From 90d40dbdbe25d02a69e37ab6691d1f909c461979 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Wed, 17 Jun 2020 12:03:44 +0200 Subject: [PATCH] Add test_model_02(). --- tests/icosagon/test_model.py | 55 ++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/tests/icosagon/test_model.py b/tests/icosagon/test_model.py index 666960c..c645d32 100644 --- a/tests/icosagon/test_model.py +++ b/tests/icosagon/test_model.py @@ -1,8 +1,15 @@ -from icosagon.data import Data +from icosagon.data import Data, \ + _equal from icosagon.model import Model -from icosagon.trainprep import PreparedData +from icosagon.trainprep import PreparedData, \ + PreparedRelationFamily, \ + PreparedRelationType, \ + TrainValTest, \ + norm_adj_mat_one_node_type import torch -import ast +from icosagon.input import OneHotInputLayer +from icosagon.convlayer import DecagonLayer +from icosagon.declayer import DecodeLayer def _is_identity_function(f): @@ -33,3 +40,45 @@ def test_model_01(): assert isinstance(m.prep_d, PreparedData) assert isinstance(m.seq, torch.nn.Sequential) assert isinstance(m.opt, torch.optim.Optimizer) + + +def test_model_02(): + d = Data() + d.add_node_type('Dummy', 10) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + mat = torch.rand(10, 10).round().to_sparse() + fam.add_relation_type('Dummy Rel', mat) + + m = Model(d, ratios=TrainValTest(1., 0., 0.)) + + assert isinstance(m.prep_d, PreparedData) + assert isinstance(m.prep_d.relation_families, list) + assert len(m.prep_d.relation_families) == 1 + assert isinstance(m.prep_d.relation_families[0], PreparedRelationFamily) + assert len(m.prep_d.relation_families[0].relation_types) == 1 + assert isinstance(m.prep_d.relation_families[0].relation_types[0], PreparedRelationType) + assert m.prep_d.relation_families[0].relation_types[0].adjacency_matrix_backward is None + assert torch.all(_equal(m.prep_d.relation_families[0].relation_types[0].adjacency_matrix, + norm_adj_mat_one_node_type(mat))) + + assert isinstance(m.seq[0], OneHotInputLayer) + assert isinstance(m.seq[1], DecagonLayer) + assert isinstance(m.seq[2], DecagonLayer) + assert isinstance(m.seq[3], DecodeLayer) + assert len(m.seq) == 4 + + +def test_model_03(): + d = Data() + d.add_node_type('Dummy', 10) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + mat = torch.rand(10, 10).round().to_sparse() + fam.add_relation_type('Dummy Rel', mat) + + m = Model(d, ratios=TrainValTest(1., 0., 0.)) + + state_dict = m.opt.state_dict() + assert isinstance(state_dict, dict) + # print(state_dict['param_groups']) + # print(list(m.seq.parameters())) + print(list(m.seq[1].parameters()))