diff --git a/src/icosagon/model.py b/src/icosagon/model.py index 5d384f0..321e600 100644 --- a/src/icosagon/model.py +++ b/src/icosagon/model.py @@ -99,13 +99,13 @@ class Model(object): def run_epoch(self): pred = self.seq(None) - batch = PredictionsBatch(pred, self.batch_size) + batch = PredictionsBatch(pred, batch_size=self.batch_size) n = len(list(iter(batch))) loss_sum = 0 - for i in range(n - 1): + for i in range(n): self.opt.zero_grad() pred = self.seq(None) - batch = PredictionsBatch(pred, self.batch_size, shuffle=True) + batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) seed = torch.rand(1).item() rng_state = torch.get_rng_state() torch.manual_seed(seed) @@ -116,7 +116,7 @@ class Model(object): (input, target) = next(it) loss = self.loss(input, target) loss.backward() - self.opt.optimize() + self.opt.step() loss_sum += loss.detach().cpu().item() return loss_sum diff --git a/tests/icosagon/test_model.py b/tests/icosagon/test_model.py index 3084031..fce7460 100644 --- a/tests/icosagon/test_model.py +++ b/tests/icosagon/test_model.py @@ -185,3 +185,14 @@ def test_model_06(): with pytest.raises(ValueError): m = Model(d, batch_size='x') + + +def test_model_07(): + d = Data() + d.add_node_type('Dummy', 10) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round()) + + m = Model(d) + + m.run_epoch()