|
|
@@ -1,8 +1,15 @@ |
|
|
|
from icosagon.data import Data
|
|
|
|
from icosagon.data import Data, \
|
|
|
|
_equal
|
|
|
|
from icosagon.model import Model
|
|
|
|
from icosagon.trainprep import PreparedData
|
|
|
|
from icosagon.trainprep import PreparedData, \
|
|
|
|
PreparedRelationFamily, \
|
|
|
|
PreparedRelationType, \
|
|
|
|
TrainValTest, \
|
|
|
|
norm_adj_mat_one_node_type
|
|
|
|
import torch
|
|
|
|
import ast
|
|
|
|
from icosagon.input import OneHotInputLayer
|
|
|
|
from icosagon.convlayer import DecagonLayer
|
|
|
|
from icosagon.declayer import DecodeLayer
|
|
|
|
|
|
|
|
|
|
|
|
def _is_identity_function(f):
|
|
|
@@ -33,3 +40,45 @@ def test_model_01(): |
|
|
|
assert isinstance(m.prep_d, PreparedData)
|
|
|
|
assert isinstance(m.seq, torch.nn.Sequential)
|
|
|
|
assert isinstance(m.opt, torch.optim.Optimizer)
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_02():
|
|
|
|
d = Data()
|
|
|
|
d.add_node_type('Dummy', 10)
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
|
|
|
|
mat = torch.rand(10, 10).round().to_sparse()
|
|
|
|
fam.add_relation_type('Dummy Rel', mat)
|
|
|
|
|
|
|
|
m = Model(d, ratios=TrainValTest(1., 0., 0.))
|
|
|
|
|
|
|
|
assert isinstance(m.prep_d, PreparedData)
|
|
|
|
assert isinstance(m.prep_d.relation_families, list)
|
|
|
|
assert len(m.prep_d.relation_families) == 1
|
|
|
|
assert isinstance(m.prep_d.relation_families[0], PreparedRelationFamily)
|
|
|
|
assert len(m.prep_d.relation_families[0].relation_types) == 1
|
|
|
|
assert isinstance(m.prep_d.relation_families[0].relation_types[0], PreparedRelationType)
|
|
|
|
assert m.prep_d.relation_families[0].relation_types[0].adjacency_matrix_backward is None
|
|
|
|
assert torch.all(_equal(m.prep_d.relation_families[0].relation_types[0].adjacency_matrix,
|
|
|
|
norm_adj_mat_one_node_type(mat)))
|
|
|
|
|
|
|
|
assert isinstance(m.seq[0], OneHotInputLayer)
|
|
|
|
assert isinstance(m.seq[1], DecagonLayer)
|
|
|
|
assert isinstance(m.seq[2], DecagonLayer)
|
|
|
|
assert isinstance(m.seq[3], DecodeLayer)
|
|
|
|
assert len(m.seq) == 4
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_03():
|
|
|
|
d = Data()
|
|
|
|
d.add_node_type('Dummy', 10)
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
|
|
|
|
mat = torch.rand(10, 10).round().to_sparse()
|
|
|
|
fam.add_relation_type('Dummy Rel', mat)
|
|
|
|
|
|
|
|
m = Model(d, ratios=TrainValTest(1., 0., 0.))
|
|
|
|
|
|
|
|
state_dict = m.opt.state_dict()
|
|
|
|
assert isinstance(state_dict, dict)
|
|
|
|
# print(state_dict['param_groups'])
|
|
|
|
# print(list(m.seq.parameters()))
|
|
|
|
print(list(m.seq[1].parameters()))
|