from .model import Model
import torch
from .batch import PredictionsBatch, \
    flatten_predictions, \
    gather_batch_indices
from typing import Callable, \
    Optional
import time


class TrainLoop(object):
    """Minibatch training loop that optimizes a Model with Adam.

    Each minibatch step recomputes the full set of predictions and then
    gathers the entries belonging to the current batch by index."""

    def __init__(
        self,
        model: Model,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100,
        shuffle: bool = False,
        generator: Optional[torch.Generator] = None) -> None:

        if not isinstance(model, Model):
            raise TypeError('model must be an instance of Model')

        lr = float(lr)

        # Accept any callable, not just plain functions; the default loss
        # and e.g. functools.partial objects must both pass this check.
        if not callable(loss):
            raise TypeError('loss must be callable')

        batch_size = int(batch_size)

        if generator is not None and not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

        self.model = model
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.generator = generator or torch.default_generator

        self.opt = None

        self.build()

    def build(self) -> None:
        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.opt = opt

    def run_epoch(self) -> float:
        batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
            shuffle=self.shuffle, generator=self.generator)

        loss_sum = 0
        for i, indices in enumerate(batch):
            print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count,
                i * batch.batch_size, batch.total_edge_count))

            t = time.time()
            self.opt.zero_grad()
            print('zero_grad() took:', time.time() - t)

            # The model computes predictions for all edges; the current
            # minibatch is selected from them below by index.
            t = time.time()
            pred = self.model(None)
            print('model() took:', time.time() - t)

            t = time.time()
            pred = flatten_predictions(pred)
            print('flatten_predictions() took:', time.time() - t)

            t = time.time()
            (input, target) = gather_batch_indices(pred, indices)
            print('gather_batch_indices() took:', time.time() - t)

            t = time.time()
            loss = self.loss(input, target)
            print('loss() took:', time.time() - t)

            t = time.time()
            loss.backward()
            print('backward() took:', time.time() - t)

            t = time.time()
            self.opt.step()
            print('step() took:', time.time() - t)

            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs: int):
        # Initialize loss so the return is well-defined even for
        # max_epochs <= 0, where no epoch runs.
        loss = None
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch
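

# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). This assumes
# a Model subclass and a prepared data object as used elsewhere in this
# package; `MyModel` and `prep_d` below are hypothetical placeholders,
# not names defined here.
#
#     model = MyModel(prep_d)
#     loop = TrainLoop(model, lr=0.001, batch_size=100, shuffle=True)
#     last_loss, best_loss, best_epoch = loop.train(max_epochs=10)
#     print('best loss %f at epoch %d' % (best_loss, best_epoch))
# ----------------------------------------------------------------------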