IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

101 lines
3.5KB

  1. from .model import Model
  2. import torch
  3. from .batch import PredictionsBatch, \
  4. flatten_predictions, \
  5. gather_batch_indices
  6. from typing import Callable
  7. from types import FunctionType
  8. import time
  9. class TrainLoop(object):
  10. def __init__(
  11. self,
  12. model: Model,
  13. lr: float = 0.001,
  14. loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
  15. torch.nn.functional.binary_cross_entropy_with_logits,
  16. batch_size: int = 100,
  17. shuffle: bool = False,
  18. generator: torch.Generator = None) -> None:
  19. if not isinstance(model, Model):
  20. raise TypeError('model must be an instance of Model')
  21. lr = float(lr)
  22. if not isinstance(loss, FunctionType):
  23. raise TypeError('loss must be a function')
  24. batch_size = int(batch_size)
  25. if generator is not None and not isinstance(generator, torch.Generator):
  26. raise TypeError('generator must be an instance of torch.Generator')
  27. self.model = model
  28. self.lr = lr
  29. self.loss = loss
  30. self.batch_size = batch_size
  31. self.shuffle = shuffle
  32. self.generator = generator or torch.default_generator
  33. self.opt = None
  34. self.build()
  35. def build(self) -> None:
  36. opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
  37. self.opt = opt
  38. def run_epoch(self):
  39. batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
  40. shuffle = self.shuffle, generator=self.generator)
  41. # pred = self.model(None)
  42. # n = len(list(iter(batch)))
  43. loss_sum = 0
  44. for i, indices in enumerate(batch):
  45. print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count))
  46. t = time.time()
  47. self.opt.zero_grad()
  48. print('zero_grad() took:', time.time() - t)
  49. t = time.time()
  50. pred = self.model(None)
  51. print('model() took:', time.time() - t)
  52. t = time.time()
  53. pred = flatten_predictions(pred)
  54. print('flatten_predictions() took:', time.time() - t)
  55. # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
  56. # seed = torch.rand(1).item()
  57. # rng_state = torch.get_rng_state()
  58. # torch.manual_seed(seed)
  59. #it = iter(batch)
  60. #torch.set_rng_state(rng_state)
  61. #for k in range(i):
  62. #_ = next(it)
  63. #(input, target) = next(it)
  64. t = time.time()
  65. (input, target) = gather_batch_indices(pred, indices)
  66. print('gather_batch_indices() took:', time.time() - t)
  67. t = time.time()
  68. loss = self.loss(input, target)
  69. print('loss() took:', time.time() - t)
  70. t = time.time()
  71. loss.backward()
  72. print('backward() took:', time.time() - t)
  73. t = time.time()
  74. self.opt.step()
  75. print('step() took:', time.time() - t)
  76. loss_sum += loss.detach().cpu().item()
  77. return loss_sum
  78. def train(self, max_epochs):
  79. best_loss = None
  80. best_epoch = None
  81. for i in range(max_epochs):
  82. loss = self.run_epoch()
  83. if best_loss is None or loss < best_loss:
  84. best_loss = loss
  85. best_epoch = i
  86. return loss, best_loss, best_epoch