IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

80 行
2.8KB

  1. from .model import Model
  2. import torch
  3. from .batch import PredictionsBatch, \
  4. flatten_predictions, \
  5. gather_batch_indices
  6. from typing import Callable
  7. from types import FunctionType
  class TrainLoop(object):
      """Mini-batch training loop for a ``Model`` using the Adam optimizer.

      Each epoch iterates over the index batches produced by a
      ``PredictionsBatch`` built from ``model.prep_d``. For every batch the
      model is run forward, its predictions are flattened, the entries
      selected by the batch indices are gathered into ``(input, target)``,
      and one optimizer step is taken on ``self.loss(input, target)``.
      """

      def __init__(self, model: 'Model', lr: float = 0.001,
                   loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] =
                   torch.nn.functional.binary_cross_entropy_with_logits,
                   batch_size: int = 100,
                   generator: torch.Generator = None) -> None:
          """Validate and store the training configuration, then build the optimizer.

          Args:
              model: the model to train; must be an instance of ``Model``.
              lr: learning rate handed to ``torch.optim.Adam`` (coerced with
                  ``float``).
              loss: callable ``loss(input, target) -> Tensor``. Any callable is
                  accepted: plain functions, builtins, ``functools.partial``
                  objects and ``torch.nn`` loss modules.
              batch_size: forwarded to ``PredictionsBatch`` (coerced with
                  ``int``).
              generator: RNG forwarded to ``PredictionsBatch``; when ``None``,
                  ``torch.default_generator`` is used.

          Raises:
              TypeError: if ``model`` is not a ``Model``, ``loss`` is not
                  callable, or ``generator`` is neither ``None`` nor a
                  ``torch.Generator``.
          """
          if not isinstance(model, Model):
              raise TypeError('model must be an instance of Model')
          lr = float(lr)
          # callable() rather than isinstance(FunctionType): the annotation
          # promises any Callable, and loss modules / partials / builtins are
          # not FunctionType instances.
          if not callable(loss):
              raise TypeError('loss must be callable')
          batch_size = int(batch_size)
          if generator is not None and \
                  not isinstance(generator, torch.Generator):
              raise TypeError('generator must be an instance of torch.Generator')
          self.model = model
          self.lr = lr
          self.loss = loss
          self.batch_size = batch_size
          self.generator = (generator if generator is not None
                            else torch.default_generator)
          self.opt = None
          self.build()

      def build(self) -> None:
          """(Re)create the Adam optimizer over the model's parameters."""
          self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

      def run_epoch(self) -> float:
          """Run one training epoch and return the sum of the per-batch losses.

          A progress line is printed before each batch. The full forward pass
          (``self.model(None)``) is repeated for every batch because the
          parameters change after each optimizer step; only the entries
          selected by the batch indices contribute to that step's loss.
          """
          batch = PredictionsBatch(self.model.prep_d,
                                   batch_size=self.batch_size,
                                   generator=self.generator)
          loss_sum = 0.0
          for i, indices in enumerate(batch):
              seen = i * batch.batch_size
              print('%.2f%% (%d/%d)' % (seen * 100 / batch.total_edge_count,
                                        seen, batch.total_edge_count))
              self.opt.zero_grad()
              pred = flatten_predictions(self.model(None))
              batch_input, batch_target = gather_batch_indices(pred, indices)
              loss = self.loss(batch_input, batch_target)
              loss.backward()
              self.opt.step()
              # .item() yields a Python number, so no graph is kept alive.
              loss_sum += loss.item()
          return loss_sum

      def train(self, max_epochs):
          """Train for ``max_epochs`` epochs.

          Returns:
              ``(last_loss, best_loss, best_epoch)``:
              - ``last_loss`` is the summed loss of the final epoch.
              - ``best_loss`` is the lowest epoch loss seen.
              - ``best_epoch`` is the zero-based index of ``best_loss``; the
                first occurrence wins on ties.

              All three are ``None`` when no epoch is run (``max_epochs <= 0``).
          """
          last_loss = None
          best_loss = None
          best_epoch = None
          for epoch in range(max_epochs):
              last_loss = self.run_epoch()
              if best_loss is None or last_loss < best_loss:
                  best_loss = last_loss
                  best_epoch = epoch
          return last_loss, best_loss, best_epoch