diff --git a/src/icosagon/trainloop.py b/src/icosagon/trainloop.py index 051019e..2ce8240 100644 --- a/src/icosagon/trainloop.py +++ b/src/icosagon/trainloop.py @@ -41,7 +41,7 @@ class TrainLoop(object): loss_sum = 0 for i in range(n): self.opt.zero_grad() - pred = self.seq(None) + pred = self.model(None) batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) seed = torch.rand(1).item() rng_state = torch.get_rng_state() diff --git a/tests/icosagon/test_model.py b/tests/icosagon/test_model.py index 960e34d..5c122d9 100644 --- a/tests/icosagon/test_model.py +++ b/tests/icosagon/test_model.py @@ -37,11 +37,7 @@ def test_model_01(): assert _is_identity_function(m.rel_activation) assert m.layer_activation == torch.nn.functional.relu assert _is_identity_function(m.dec_activation) - assert m.lr == 0.001 - assert m.loss == torch.nn.functional.binary_cross_entropy_with_logits - assert m.batch_size == 100 assert isinstance(m.seq, torch.nn.Sequential) - assert isinstance(m.opt, torch.optim.Optimizer) def test_model_02(): @@ -83,15 +79,10 @@ def test_model_03(): m = Model(prep_d) - state_dict = m.opt.state_dict() - assert isinstance(state_dict, dict) - # print(state_dict['param_groups']) - # print(list(m.seq.parameters())) assert len(list(m.seq[0].parameters())) == 1 assert len(list(m.seq[1].parameters())) == 1 assert len(list(m.seq[2].parameters())) == 1 assert len(list(m.seq[3].parameters())) == 2 - # print(list(m.seq[1].parameters())) def test_model_04(): @@ -189,25 +180,3 @@ def test_model_07(): with pytest.raises(TypeError): m = Model(prep_d, dec_activation='x') - - with pytest.raises(ValueError): - m = Model(prep_d, lr='x') - - with pytest.raises(TypeError): - m = Model(prep_d, loss=1) - - with pytest.raises(ValueError): - m = Model(prep_d, batch_size='x') - - -def test_model_08(): - d = Data() - d.add_node_type('Dummy', 10) - fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) - fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round()) - - prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) - - m = Model(prep_d) - - m.run_epoch() diff --git a/tests/icosagon/test_trainloop.py b/tests/icosagon/test_trainloop.py index 2476c9c..04956c6 100644 --- a/tests/icosagon/test_trainloop.py +++ b/tests/icosagon/test_trainloop.py @@ -22,3 +22,18 @@ def test_train_loop_01(): assert loop.lr == 0.001 assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits assert loop.batch_size == 100 + + +def test_train_loop_02(): + d = Data() + d.add_node_type('Dummy', 10) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round()) + + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) + + m = Model(prep_d) + + loop = TrainLoop(m) + + loop.run_epoch()