|
@@ -37,11 +37,7 @@ def test_model_01(): |
|
|
assert _is_identity_function(m.rel_activation)
|
|
|
assert _is_identity_function(m.rel_activation)
|
|
|
assert m.layer_activation == torch.nn.functional.relu
|
|
|
assert m.layer_activation == torch.nn.functional.relu
|
|
|
assert _is_identity_function(m.dec_activation)
|
|
|
assert _is_identity_function(m.dec_activation)
|
|
|
assert m.lr == 0.001
|
|
|
|
|
|
assert m.loss == torch.nn.functional.binary_cross_entropy_with_logits
|
|
|
|
|
|
assert m.batch_size == 100
|
|
|
|
|
|
assert isinstance(m.seq, torch.nn.Sequential)
|
|
|
assert isinstance(m.seq, torch.nn.Sequential)
|
|
|
assert isinstance(m.opt, torch.optim.Optimizer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_02():
|
|
|
def test_model_02():
|
|
@@ -83,15 +79,10 @@ def test_model_03(): |
|
|
|
|
|
|
|
|
m = Model(prep_d)
|
|
|
m = Model(prep_d)
|
|
|
|
|
|
|
|
|
state_dict = m.opt.state_dict()
|
|
|
|
|
|
assert isinstance(state_dict, dict)
|
|
|
|
|
|
# print(state_dict['param_groups'])
|
|
|
|
|
|
# print(list(m.seq.parameters()))
|
|
|
|
|
|
assert len(list(m.seq[0].parameters())) == 1
|
|
|
assert len(list(m.seq[0].parameters())) == 1
|
|
|
assert len(list(m.seq[1].parameters())) == 1
|
|
|
assert len(list(m.seq[1].parameters())) == 1
|
|
|
assert len(list(m.seq[2].parameters())) == 1
|
|
|
assert len(list(m.seq[2].parameters())) == 1
|
|
|
assert len(list(m.seq[3].parameters())) == 2
|
|
|
assert len(list(m.seq[3].parameters())) == 2
|
|
|
# print(list(m.seq[1].parameters()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_04():
|
|
|
def test_model_04():
|
|
@@ -189,25 +180,3 @@ def test_model_07(): |
|
|
|
|
|
|
|
|
with pytest.raises(TypeError):
|
|
|
with pytest.raises(TypeError):
|
|
|
m = Model(prep_d, dec_activation='x')
|
|
|
m = Model(prep_d, dec_activation='x')
|
|
|
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
|
m = Model(prep_d, lr='x')
|
|
|
|
|
|
|
|
|
|
|
|
with pytest.raises(TypeError):
|
|
|
|
|
|
m = Model(prep_d, loss=1)
|
|
|
|
|
|
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
|
m = Model(prep_d, batch_size='x')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_08():
|
|
|
|
|
|
d = Data()
|
|
|
|
|
|
d.add_node_type('Dummy', 10)
|
|
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
|
|
|
|
|
|
fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
|
|
|
|
|
|
|
|
|
|
|
|
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
|
|
|
|
|
|
|
|
|
|
|
|
m = Model(prep_d)
|
|
|
|
|
|
|
|
|
|
|
|
m.run_epoch()
|
|
|
|