IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

trainloop.py 2.1KB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from .model import Model
  2. import torch
  3. from .batch import PredictionsBatch
  4. from typing import Callable
  5. from types import FunctionType
  6. class TrainLoop(object):
  7. def __init__(self, model: Model, lr: float = 0.001,
  8. loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
  9. torch.nn.functional.binary_cross_entropy_with_logits,
  10. batch_size: int = 100) -> None:
  11. if not isinstance(model, Model):
  12. raise TypeError('model must be an instance of Model')
  13. lr = float(lr)
  14. if not isinstance(loss, FunctionType):
  15. raise TypeError('loss must be a function')
  16. batch_size = int(batch_size)
  17. self.model = model
  18. self.lr = lr
  19. self.loss = loss
  20. self.batch_size = batch_size
  21. self.opt = None
  22. self.build()
  23. def build(self) -> None:
  24. opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
  25. self.opt = opt
  26. def run_epoch(self):
  27. pred = self.model(None)
  28. batch = PredictionsBatch(pred, batch_size=self.batch_size)
  29. n = len(list(iter(batch)))
  30. loss_sum = 0
  31. for i in range(n):
  32. self.opt.zero_grad()
  33. pred = self.seq(None)
  34. batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
  35. seed = torch.rand(1).item()
  36. rng_state = torch.get_rng_state()
  37. torch.manual_seed(seed)
  38. it = iter(batch)
  39. torch.set_rng_state(rng_state)
  40. for k in range(i):
  41. _ = next(it)
  42. (input, target) = next(it)
  43. loss = self.loss(input, target)
  44. loss.backward()
  45. self.opt.step()
  46. loss_sum += loss.detach().cpu().item()
  47. return loss_sum
  48. def train(self, max_epochs):
  49. best_loss = None
  50. best_epoch = None
  51. for i in range(max_epochs):
  52. loss = self.run_epoch()
  53. if best_loss is None or loss < best_loss:
  54. best_loss = loss
  55. best_epoch = i
  56. return loss, best_loss, best_epoch