From b9e7e395cb78c3ad0a0abc4d020c769f27612ac5 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 16 Jul 2020 20:23:26 +0200 Subject: [PATCH] Add shuffle param to TrainLoop. --- src/icosagon/trainloop.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/icosagon/trainloop.py b/src/icosagon/trainloop.py index 3392a40..098baba 100644 --- a/src/icosagon/trainloop.py +++ b/src/icosagon/trainloop.py @@ -8,10 +8,15 @@ from types import FunctionType class TrainLoop(object): - def __init__(self, model: Model, lr: float = 0.001, + def __init__( + self, + model: Model, + lr: float = 0.001, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ torch.nn.functional.binary_cross_entropy_with_logits, - batch_size: int = 100, generator: torch.Generator = None) -> None: + batch_size: int = 100, + shuffle: bool = False, + generator: torch.Generator = None) -> None: if not isinstance(model, Model): raise TypeError('model must be an instance of Model') @@ -30,6 +35,7 @@ class TrainLoop(object): self.lr = lr self.loss = loss self.batch_size = batch_size + self.shuffle = shuffle self.generator = generator or torch.default_generator self.opt = None @@ -42,7 +48,7 @@ class TrainLoop(object): def run_epoch(self): batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size, - generator=self.generator) + shuffle = self.shuffle, generator=self.generator) # pred = self.model(None) # n = len(list(iter(batch))) loss_sum = 0