IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

41 lines
1.3KB

  1. from .model import Model
  2. from .batch import Batcher
  3. class TrainLoop(object):
  4. def __init__(self, model: Model,
  5. pos_batcher: Batcher,
  6. neg_batcher: Batcher,
  7. max_epochs: int = 50) -> None:
  8. if not isinstance(model, Model):
  9. raise TypeError('model must be an instance of Model')
  10. if not isinstance(pos_batcher, Batcher):
  11. raise TypeError('pos_batcher must be an instance of Batcher')
  12. if not isinstance(neg_batcher, Batcher):
  13. raise TypeError('neg_batcher must be an instance of Batcher')
  14. self.model = model
  15. self.pos_batcher = pos_batcher
  16. self.neg_batcher = neg_batcher
  17. self.max_epochs = int(num_epochs)
  18. def run_epoch(self) -> None:
  19. pos_it = iter(self.pos_batcher)
  20. neg_it = iter(self.neg_batcher)
  21. while True:
  22. try:
  23. pos_batch = next(pos_it)
  24. neg_batch = next(neg_it)
  25. except StopIteration:
  26. break
  27. if len(pos_batch.edges) != len(neg_batch.edges):
  28. raise ValueError('Positive and negative batch should have same length')
  29. def run(self) -> None:
  30. for epoch in range(self.max_epochs):
  31. self.run_epoch()