| @@ -8,10 +8,15 @@ from types import FunctionType | |||||
| class TrainLoop(object): | class TrainLoop(object): | ||||
| def __init__(self, model: Model, lr: float = 0.001, | |||||
| def __init__( | |||||
| self, | |||||
| model: Model, | |||||
| lr: float = 0.001, | |||||
| loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | ||||
| torch.nn.functional.binary_cross_entropy_with_logits, | torch.nn.functional.binary_cross_entropy_with_logits, | ||||
| batch_size: int = 100, generator: torch.Generator = None) -> None: | |||||
| batch_size: int = 100, | |||||
| shuffle: bool = False, | |||||
| generator: torch.Generator = None) -> None: | |||||
| if not isinstance(model, Model): | if not isinstance(model, Model): | ||||
| raise TypeError('model must be an instance of Model') | raise TypeError('model must be an instance of Model') | ||||
| @@ -30,6 +35,7 @@ class TrainLoop(object): | |||||
| self.lr = lr | self.lr = lr | ||||
| self.loss = loss | self.loss = loss | ||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.shuffle = shuffle | |||||
| self.generator = generator or torch.default_generator | self.generator = generator or torch.default_generator | ||||
| self.opt = None | self.opt = None | ||||
| @@ -42,7 +48,7 @@ class TrainLoop(object): | |||||
| def run_epoch(self): | def run_epoch(self): | ||||
| batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size, | batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size, | ||||
| generator=self.generator) | |||||
| shuffle = self.shuffle, generator=self.generator) | |||||
| # pred = self.model(None) | # pred = self.model(None) | ||||
| # n = len(list(iter(batch))) | # n = len(list(iter(batch))) | ||||
| loss_sum = 0 | loss_sum = 0 | ||||