IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Add test_model_02().

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
90d40dbdbe
共有 1 个文件被更改,包括 52 次插入3 次删除
  1. +52
    -3
      tests/icosagon/test_model.py

+ 52
- 3
tests/icosagon/test_model.py 查看文件

@@ -1,8 +1,15 @@
from icosagon.data import Data
from icosagon.data import Data, \
_equal
from icosagon.model import Model
from icosagon.trainprep import PreparedData
from icosagon.trainprep import PreparedData, \
PreparedRelationFamily, \
PreparedRelationType, \
TrainValTest, \
norm_adj_mat_one_node_type
import torch
import ast
from icosagon.input import OneHotInputLayer
from icosagon.convlayer import DecagonLayer
from icosagon.declayer import DecodeLayer
def _is_identity_function(f):
@@ -33,3 +40,45 @@ def test_model_01():
assert isinstance(m.prep_d, PreparedData)
assert isinstance(m.seq, torch.nn.Sequential)
assert isinstance(m.opt, torch.optim.Optimizer)
def test_model_02():
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
mat = torch.rand(10, 10).round().to_sparse()
fam.add_relation_type('Dummy Rel', mat)
m = Model(d, ratios=TrainValTest(1., 0., 0.))
assert isinstance(m.prep_d, PreparedData)
assert isinstance(m.prep_d.relation_families, list)
assert len(m.prep_d.relation_families) == 1
assert isinstance(m.prep_d.relation_families[0], PreparedRelationFamily)
assert len(m.prep_d.relation_families[0].relation_types) == 1
assert isinstance(m.prep_d.relation_families[0].relation_types[0], PreparedRelationType)
assert m.prep_d.relation_families[0].relation_types[0].adjacency_matrix_backward is None
assert torch.all(_equal(m.prep_d.relation_families[0].relation_types[0].adjacency_matrix,
norm_adj_mat_one_node_type(mat)))
assert isinstance(m.seq[0], OneHotInputLayer)
assert isinstance(m.seq[1], DecagonLayer)
assert isinstance(m.seq[2], DecagonLayer)
assert isinstance(m.seq[3], DecodeLayer)
assert len(m.seq) == 4
def test_model_03():
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
mat = torch.rand(10, 10).round().to_sparse()
fam.add_relation_type('Dummy Rel', mat)
m = Model(d, ratios=TrainValTest(1., 0., 0.))
state_dict = m.opt.state_dict()
assert isinstance(state_dict, dict)
# print(state_dict['param_groups'])
# print(list(m.seq.parameters()))
print(list(m.seq[1].parameters()))

正在加载...
取消
保存