IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

86 строк
2.9KB

  1. from .model import Model
  2. import torch
  3. from .batch import PredictionsBatch, \
  4. flatten_predictions, \
  5. gather_batch_indices
  6. from typing import Callable
  7. from types import FunctionType
  8. class TrainLoop(object):
  9. def __init__(
  10. self,
  11. model: Model,
  12. lr: float = 0.001,
  13. loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
  14. torch.nn.functional.binary_cross_entropy_with_logits,
  15. batch_size: int = 100,
  16. shuffle: bool = False,
  17. generator: torch.Generator = None) -> None:
  18. if not isinstance(model, Model):
  19. raise TypeError('model must be an instance of Model')
  20. lr = float(lr)
  21. if not isinstance(loss, FunctionType):
  22. raise TypeError('loss must be a function')
  23. batch_size = int(batch_size)
  24. if generator is not None and not isinstance(generator, torch.Generator):
  25. raise TypeError('generator must be an instance of torch.Generator')
  26. self.model = model
  27. self.lr = lr
  28. self.loss = loss
  29. self.batch_size = batch_size
  30. self.shuffle = shuffle
  31. self.generator = generator or torch.default_generator
  32. self.opt = None
  33. self.build()
  34. def build(self) -> None:
  35. opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
  36. self.opt = opt
  37. def run_epoch(self):
  38. batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
  39. shuffle = self.shuffle, generator=self.generator)
  40. # pred = self.model(None)
  41. # n = len(list(iter(batch)))
  42. loss_sum = 0
  43. for i, indices in enumerate(batch):
  44. print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count))
  45. self.opt.zero_grad()
  46. pred = self.model(None)
  47. pred = flatten_predictions(pred)
  48. # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
  49. # seed = torch.rand(1).item()
  50. # rng_state = torch.get_rng_state()
  51. # torch.manual_seed(seed)
  52. #it = iter(batch)
  53. #torch.set_rng_state(rng_state)
  54. #for k in range(i):
  55. #_ = next(it)
  56. #(input, target) = next(it)
  57. (input, target) = gather_batch_indices(pred, indices)
  58. loss = self.loss(input, target)
  59. loss.backward()
  60. self.opt.step()
  61. loss_sum += loss.detach().cpu().item()
  62. return loss_sum
  63. def train(self, max_epochs):
  64. best_loss = None
  65. best_epoch = None
  66. for i in range(max_epochs):
  67. loss = self.run_epoch()
  68. if best_loss is None or loss < best_loss:
  69. best_loss = loss
  70. best_epoch = i
  71. return loss, best_loss, best_epoch