|
- from .model import Model
- import torch
- from .batch import PredictionsBatch
- from typing import Callable
- from types import FunctionType
-
-
class TrainLoop:
    """Mini-batch training loop for a ``Model`` using the Adam optimizer.

    Each optimizer step re-runs ``self.model(None)`` so the predictions come
    from the current parameters. It then takes one mini-batch of
    ``(input, target)`` pairs from a ``PredictionsBatch`` wrapper around
    those predictions and performs a single Adam update on that batch's loss.
    """

    def __init__(self, model: Model, lr: float = 0.001,
                 loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] =
                 torch.nn.functional.binary_cross_entropy_with_logits,
                 batch_size: int = 100) -> None:
        """Validate the configuration and build the optimizer.

        Args:
            model: The model to train; must be an instance of ``Model``.
            lr: Adam learning rate (coerced to ``float``).
            loss: Callable taking ``(input, target)`` tensors and returning a
                single-element tensor (``backward()`` is called on it with no
                gradient argument, and ``.item()`` is read from it).
            batch_size: Mini-batch size handed to ``PredictionsBatch``
                (coerced to ``int``).

        Raises:
            TypeError: If ``model`` is not a ``Model`` or ``loss`` is not
                callable.
        """
        if not isinstance(model, Model):
            raise TypeError('model must be an instance of Model')

        lr = float(lr)

        # Accept any callable (plain functions, lambdas, nn.Module loss
        # instances, functools.partial, builtins), not only Python functions.
        if not callable(loss):
            raise TypeError('loss must be callable')

        batch_size = int(batch_size)

        self.model = model            # model being optimized
        self.lr = lr                  # Adam learning rate
        self.loss = loss              # loss(input, target) -> tensor
        self.batch_size = batch_size  # mini-batch size for PredictionsBatch

        self.opt = None               # set by build()

        self.build()

    def build(self) -> None:
        """(Re)create the Adam optimizer over the model's parameters."""
        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def run_epoch(self) -> float:
        """Run one epoch: one optimizer step per mini-batch.

        One random seed is drawn per epoch. Every step re-seeds the shuffle
        with it, so all steps of the epoch see the same shuffled order, and
        step ``i`` trains on batch ``i`` of that order. Each batch is
        therefore used exactly once per epoch. The global RNG stream outside
        the shuffle is left undisturbed.

        Returns:
            The sum (not the mean) of the per-batch loss values.
        """
        # Count mini-batches. Gradients are not needed for this pass, so
        # avoid building an autograd graph.
        with torch.no_grad():
            pred = self.model(None)
        n_batches = sum(
            1 for _ in PredictionsBatch(pred, batch_size=self.batch_size))

        # torch.manual_seed converts its argument to int, so draw an integer
        # seed (a float in [0, 1) would always become 0). Draw it once per
        # epoch so every step below reproduces the same order.
        seed = int(torch.randint(0, 2 ** 31 - 1, (1,)).item())

        loss_sum = 0.0
        for i in range(n_batches):
            self.opt.zero_grad()
            # Fresh forward pass: parameters changed after the previous step.
            pred = self.model(None)
            batch = PredictionsBatch(pred, batch_size=self.batch_size,
                                     shuffle=True)
            # fork_rng saves and restores the CPU and CUDA RNG states, so
            # re-seeding here does not leak into the rest of the program.
            # Consuming the iterator inside the fork keeps this correct
            # whether PredictionsBatch shuffles in iter() or lazily in next().
            with torch.random.fork_rng():
                torch.manual_seed(seed)
                it = iter(batch)
                for _ in range(i):  # skip the batches earlier steps used
                    next(it)
                inputs, target = next(it)
            loss = self.loss(inputs, target)
            loss.backward()
            self.opt.step()
            loss_sum += loss.item()
        return loss_sum

    def train(self, max_epochs: int) -> tuple:
        """Run up to ``max_epochs`` epochs and track the best epoch loss.

        Args:
            max_epochs: Number of epochs to run.

        Returns:
            ``(last_loss, best_loss, best_epoch)``: the summed loss of the
            final epoch, the lowest summed epoch loss, and the 0-based index
            of the epoch that achieved it. All three are ``None`` when no
            epoch is run.
        """
        last_loss = None  # bound up front so max_epochs == 0 does not crash
        best_loss = None
        best_epoch = None
        for epoch in range(max_epochs):
            last_loss = self.run_epoch()
            if best_loss is None or last_loss < best_loss:
                best_loss = last_loss
                best_epoch = epoch
        return last_loss, best_loss, best_epoch
|