IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

79 lines
2.6KB

  1. from .model import Model
  2. import torch
  3. from .batch import PredictionsBatch, \
  4. flatten_predictions, \
  5. gather_batch_indices
  6. from typing import Callable
  7. from types import FunctionType
  8. class TrainLoop(object):
  9. def __init__(self, model: Model, lr: float = 0.001,
  10. loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
  11. torch.nn.functional.binary_cross_entropy_with_logits,
  12. batch_size: int = 100, generator: torch.Generator = None) -> None:
  13. if not isinstance(model, Model):
  14. raise TypeError('model must be an instance of Model')
  15. lr = float(lr)
  16. if not isinstance(loss, FunctionType):
  17. raise TypeError('loss must be a function')
  18. batch_size = int(batch_size)
  19. if generator is not None and not isinstance(generator, torch.Generator):
  20. raise TypeError('generator must be an instance of torch.Generator')
  21. self.model = model
  22. self.lr = lr
  23. self.loss = loss
  24. self.batch_size = batch_size
  25. self.generator = generator or torch.default_generator
  26. self.opt = None
  27. self.build()
  28. def build(self) -> None:
  29. opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
  30. self.opt = opt
  31. def run_epoch(self):
  32. batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
  33. generator=self.generator)
  34. # pred = self.model(None)
  35. # n = len(list(iter(batch)))
  36. loss_sum = 0
  37. for indices in batch:
  38. self.opt.zero_grad()
  39. pred = self.model(None)
  40. pred = flatten_predictions(pred)
  41. # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
  42. # seed = torch.rand(1).item()
  43. # rng_state = torch.get_rng_state()
  44. # torch.manual_seed(seed)
  45. #it = iter(batch)
  46. #torch.set_rng_state(rng_state)
  47. #for k in range(i):
  48. #_ = next(it)
  49. #(input, target) = next(it)
  50. (input, target) = gather_batch_indices(pred, indices)
  51. loss = self.loss(input, target)
  52. loss.backward()
  53. self.opt.step()
  54. loss_sum += loss.detach().cpu().item()
  55. return loss_sum
  56. def train(self, max_epochs):
  57. best_loss = None
  58. best_epoch = None
  59. for i in range(max_epochs):
  60. loss = self.run_epoch()
  61. if best_loss is None or loss < best_loss:
  62. best_loss = loss
  63. best_epoch = i
  64. return loss, best_loss, best_epoch