| @@ -1,8 +1,15 @@ | |||||
| from icosagon.data import Data | |||||
| from icosagon.data import Data, \ | |||||
| _equal | |||||
| from icosagon.model import Model | from icosagon.model import Model | ||||
| from icosagon.trainprep import PreparedData | |||||
| from icosagon.trainprep import PreparedData, \ | |||||
| PreparedRelationFamily, \ | |||||
| PreparedRelationType, \ | |||||
| TrainValTest, \ | |||||
| norm_adj_mat_one_node_type | |||||
| import torch | import torch | ||||
| import ast | |||||
| from icosagon.input import OneHotInputLayer | |||||
| from icosagon.convlayer import DecagonLayer | |||||
| from icosagon.declayer import DecodeLayer | |||||
| def _is_identity_function(f): | def _is_identity_function(f): | ||||
| @@ -33,3 +40,45 @@ def test_model_01(): | |||||
| assert isinstance(m.prep_d, PreparedData) | assert isinstance(m.prep_d, PreparedData) | ||||
| assert isinstance(m.seq, torch.nn.Sequential) | assert isinstance(m.seq, torch.nn.Sequential) | ||||
| assert isinstance(m.opt, torch.optim.Optimizer) | assert isinstance(m.opt, torch.optim.Optimizer) | ||||
| def test_model_02(): | |||||
| d = Data() | |||||
| d.add_node_type('Dummy', 10) | |||||
| fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||||
| mat = torch.rand(10, 10).round().to_sparse() | |||||
| fam.add_relation_type('Dummy Rel', mat) | |||||
| m = Model(d, ratios=TrainValTest(1., 0., 0.)) | |||||
| assert isinstance(m.prep_d, PreparedData) | |||||
| assert isinstance(m.prep_d.relation_families, list) | |||||
| assert len(m.prep_d.relation_families) == 1 | |||||
| assert isinstance(m.prep_d.relation_families[0], PreparedRelationFamily) | |||||
| assert len(m.prep_d.relation_families[0].relation_types) == 1 | |||||
| assert isinstance(m.prep_d.relation_families[0].relation_types[0], PreparedRelationType) | |||||
| assert m.prep_d.relation_families[0].relation_types[0].adjacency_matrix_backward is None | |||||
| assert torch.all(_equal(m.prep_d.relation_families[0].relation_types[0].adjacency_matrix, | |||||
| norm_adj_mat_one_node_type(mat))) | |||||
| assert isinstance(m.seq[0], OneHotInputLayer) | |||||
| assert isinstance(m.seq[1], DecagonLayer) | |||||
| assert isinstance(m.seq[2], DecagonLayer) | |||||
| assert isinstance(m.seq[3], DecodeLayer) | |||||
| assert len(m.seq) == 4 | |||||
| def test_model_03(): | |||||
| d = Data() | |||||
| d.add_node_type('Dummy', 10) | |||||
| fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||||
| mat = torch.rand(10, 10).round().to_sparse() | |||||
| fam.add_relation_type('Dummy Rel', mat) | |||||
| m = Model(d, ratios=TrainValTest(1., 0., 0.)) | |||||
| state_dict = m.opt.state_dict() | |||||
| assert isinstance(state_dict, dict) | |||||
| # print(state_dict['param_groups']) | |||||
| # print(list(m.seq.parameters())) | |||||
| print(list(m.seq[1].parameters())) | |||||