|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- from .model import Model
- import torch
- from .batch import PredictionsBatch, \
- flatten_predictions, \
- gather_batch_indices
- from typing import Callable
- from types import FunctionType
-
-
class TrainLoop(object):
    """Minibatch gradient-descent training loop for a ``Model``.

    Each epoch iterates over index batches produced by ``PredictionsBatch``,
    recomputes the model's predictions, gathers the entries belonging to the
    current batch, applies the loss, and takes one Adam step per batch.
    """

    def __init__(self, model: Model, lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100, generator: torch.Generator = None) -> None:
        """Configure the training loop and build the optimizer.

        Args:
            model: the ``Model`` instance to train.
            lr: Adam learning rate.
            loss: callable mapping ``(input, target)`` tensors to a scalar
                loss tensor.
            batch_size: number of edges per minibatch.
            generator: RNG used for batch shuffling; defaults to
                ``torch.default_generator`` when ``None``.

        Raises:
            TypeError: if ``model``, ``loss`` or ``generator`` has the wrong
                type.
        """

        if not isinstance(model, Model):
            raise TypeError('model must be an instance of Model')

        lr = float(lr)

        # BUGFIX: the previous isinstance(loss, FunctionType) check rejected
        # perfectly valid loss callables — functools.partial objects, bound
        # methods, and C-implemented functions (many torch functionals are
        # builtins, not FunctionType). Any callable is acceptable here.
        if not callable(loss):
            raise TypeError('loss must be callable')

        batch_size = int(batch_size)

        if generator is not None and not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

        self.model = model
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size
        self.generator = generator or torch.default_generator

        self.opt = None  # created by build()

        self.build()

    def build(self) -> None:
        """Create the Adam optimizer over the model's parameters."""
        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def run_epoch(self) -> float:
        """Run one full pass over the training edges.

        Returns:
            The sum of the per-batch loss values accumulated over the epoch.
        """

        batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
            generator=self.generator)

        loss_sum = 0.
        for i, indices in enumerate(batch):
            # Progress report; guard the division so an empty edge set does
            # not raise ZeroDivisionError.
            if batch.total_edge_count:
                print('%.2f%% (%d/%d)' % (
                    i * batch.batch_size * 100 / batch.total_edge_count,
                    i * batch.batch_size, batch.total_edge_count))

            self.opt.zero_grad()

            # NOTE(review): the model recomputes predictions for the whole
            # graph every batch; only the subset selected by `indices`
            # contributes to this step's loss — TODO confirm this is the
            # intended (if costly) behavior.
            pred = self.model(None)
            pred = flatten_predictions(pred)
            (input, target) = gather_batch_indices(pred, indices)

            loss = self.loss(input, target)
            loss.backward()
            self.opt.step()

            loss_sum += loss.detach().cpu().item()

        return loss_sum

    def train(self, max_epochs):
        """Train for up to ``max_epochs`` epochs.

        Args:
            max_epochs: number of epochs to run (coerced to int).

        Returns:
            ``(last_loss, best_loss, best_epoch)``. All three are ``None``
            when ``max_epochs`` is 0.
        """

        max_epochs = int(max_epochs)

        # BUGFIX: `loss` was unbound when max_epochs == 0, so the final
        # `return loss, ...` raised NameError. Initialize it explicitly.
        loss = None
        best_loss = None
        best_epoch = None

        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i

        return loss, best_loss, best_epoch
|