| @@ -1,8 +1,15 @@ | |||
| from icosagon.data import Data | |||
| from icosagon.data import Data, \ | |||
| _equal | |||
| from icosagon.model import Model | |||
| from icosagon.trainprep import PreparedData | |||
| from icosagon.trainprep import PreparedData, \ | |||
| PreparedRelationFamily, \ | |||
| PreparedRelationType, \ | |||
| TrainValTest, \ | |||
| norm_adj_mat_one_node_type | |||
| import torch | |||
| import ast | |||
| from icosagon.input import OneHotInputLayer | |||
| from icosagon.convlayer import DecagonLayer | |||
| from icosagon.declayer import DecodeLayer | |||
| def _is_identity_function(f): | |||
| @@ -33,3 +40,45 @@ def test_model_01(): | |||
| assert isinstance(m.prep_d, PreparedData) | |||
| assert isinstance(m.seq, torch.nn.Sequential) | |||
| assert isinstance(m.opt, torch.optim.Optimizer) | |||
| def test_model_02(): | |||
| d = Data() | |||
| d.add_node_type('Dummy', 10) | |||
| fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||
| mat = torch.rand(10, 10).round().to_sparse() | |||
| fam.add_relation_type('Dummy Rel', mat) | |||
| m = Model(d, ratios=TrainValTest(1., 0., 0.)) | |||
| assert isinstance(m.prep_d, PreparedData) | |||
| assert isinstance(m.prep_d.relation_families, list) | |||
| assert len(m.prep_d.relation_families) == 1 | |||
| assert isinstance(m.prep_d.relation_families[0], PreparedRelationFamily) | |||
| assert len(m.prep_d.relation_families[0].relation_types) == 1 | |||
| assert isinstance(m.prep_d.relation_families[0].relation_types[0], PreparedRelationType) | |||
| assert m.prep_d.relation_families[0].relation_types[0].adjacency_matrix_backward is None | |||
| assert torch.all(_equal(m.prep_d.relation_families[0].relation_types[0].adjacency_matrix, | |||
| norm_adj_mat_one_node_type(mat))) | |||
| assert isinstance(m.seq[0], OneHotInputLayer) | |||
| assert isinstance(m.seq[1], DecagonLayer) | |||
| assert isinstance(m.seq[2], DecagonLayer) | |||
| assert isinstance(m.seq[3], DecodeLayer) | |||
| assert len(m.seq) == 4 | |||
| def test_model_03(): | |||
| d = Data() | |||
| d.add_node_type('Dummy', 10) | |||
| fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||
| mat = torch.rand(10, 10).round().to_sparse() | |||
| fam.add_relation_type('Dummy Rel', mat) | |||
| m = Model(d, ratios=TrainValTest(1., 0., 0.)) | |||
| state_dict = m.opt.state_dict() | |||
| assert isinstance(state_dict, dict) | |||
| # print(state_dict['param_groups']) | |||
| # print(list(m.seq.parameters())) | |||
| print(list(m.seq[1].parameters())) | |||