from .model import Model
import torch
from .batch import PredictionsBatch
from typing import Callable
from types import FunctionType


class TrainLoop(object):
    """Minimal training loop for a Model.

    Builds an Adam optimizer over the model's parameters and, each epoch,
    walks the mini-batches produced by PredictionsBatch, taking one
    optimizer step per batch and accumulating the scalar loss.
    """

    def __init__(self, model: Model, lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100) -> None:
        """Validate arguments, store configuration and build the optimizer.

        Args:
            model: the Model whose parameters are optimized; also called as
                ``model(None)`` to produce predictions each step.
            lr: Adam learning rate (coerced with ``float()``).
            loss: callable mapping (predictions, targets) to a scalar loss
                tensor; defaults to BCE-with-logits.
            batch_size: mini-batch size (coerced with ``int()``).

        Raises:
            TypeError: if ``model`` is not a Model or ``loss`` is not callable.
        """
        if not isinstance(model, Model):
            raise TypeError('model must be an instance of Model')
        lr = float(lr)
        # FIX: the previous isinstance(loss, FunctionType) check rejected
        # valid callables (functools.partial, callable objects, builtins)
        # even though the annotation promises only Callable. Accept anything
        # callable; every previously-accepted value still passes.
        if not callable(loss):
            raise TypeError('loss must be callable')
        batch_size = int(batch_size)

        self.model = model
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size
        self.opt = None
        self.build()

    def build(self) -> None:
        """Create the Adam optimizer over the model's parameters."""
        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.opt = opt

    def run_epoch(self):
        """Run one epoch and return the sum of per-batch loss values.

        The model is re-evaluated before every optimizer step; to pair each
        fresh prediction with the next batch of the same shuffled order, the
        iterator is re-created every step under a controlled seed and then
        fast-forwarded past the batches consumed by earlier steps.
        """
        # Dry run with an unshuffled batch just to count the batches per epoch.
        pred = self.model(None)
        batch = PredictionsBatch(pred, batch_size=self.batch_size)
        n = len(list(iter(batch)))
        loss_sum = 0
        for i in range(n):
            self.opt.zero_grad()
            # BUG FIX: was `self.seq(None)` — TrainLoop defines no `seq`
            # attribute (see __init__), which raised AttributeError on the
            # first iteration. Re-evaluate the model, mirroring the dry run
            # above.
            pred = self.model(None)
            batch = PredictionsBatch(pred, batch_size=self.batch_size,
                shuffle=True)
            # Shuffle under a throwaway seed, then restore the global RNG
            # state so the rest of training is unaffected.
            # NOTE(review): torch.manual_seed() truncates its argument to
            # int, so this float in [0, 1) seeds as 0 on every iteration —
            # the skip-ahead below only pairs batches consistently if the
            # permutation is identical across iterations, which this
            # truncation happens to guarantee. Confirm whether an integer
            # seed drawn once per epoch was the actual intent.
            seed = torch.rand(1).item()
            rng_state = torch.get_rng_state()
            torch.manual_seed(seed)
            it = iter(batch)
            torch.set_rng_state(rng_state)
            # Fast-forward past the i batches already consumed by earlier
            # optimizer steps this epoch.
            for _ in range(i):
                _ = next(it)
            # (renamed from `input`, which shadowed the builtin)
            pred_batch, target = next(it)
            loss = self.loss(pred_batch, target)
            loss.backward()
            self.opt.step()
            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        """Train for ``max_epochs`` epochs, tracking the best epoch loss.

        Returns:
            (last_loss, best_loss, best_epoch): loss of the final epoch, the
            lowest epoch loss observed, and its epoch index. All three are
            None when ``max_epochs`` is 0.
        """
        # FIX: guard so `loss` is defined at the return even when the loop
        # body never runs (max_epochs == 0) — previously UnboundLocalError.
        loss = None
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch