from .model import Model
from .batch import Batcher


class TrainLoop:
    """Drive training epochs over paired positive/negative batch streams.

    Each epoch walks ``pos_batcher`` and ``neg_batcher`` in lockstep and
    checks that paired batches have the same number of edges.
    """

    def __init__(self, model: Model, pos_batcher: Batcher,
                 neg_batcher: Batcher, max_epochs: int = 50) -> None:
        """Store the model, the two batch sources and the epoch budget.

        Args:
            model: The model to train; must be a ``Model`` instance.
            pos_batcher: Iterable source of positive batches; must be a
                ``Batcher`` instance.
            neg_batcher: Iterable source of negative batches; must be a
                ``Batcher`` instance.
            max_epochs: Number of epochs ``run`` performs; coerced with
                ``int()``.

        Raises:
            TypeError: If any argument has the wrong type.
        """
        if not isinstance(model, Model):
            raise TypeError('model must be an instance of Model')
        if not isinstance(pos_batcher, Batcher):
            raise TypeError('pos_batcher must be an instance of Batcher')
        if not isinstance(neg_batcher, Batcher):
            raise TypeError('neg_batcher must be an instance of Batcher')
        self.model = model
        self.pos_batcher = pos_batcher
        self.neg_batcher = neg_batcher
        # Previously referenced an undefined name ``num_epochs`` (NameError
        # on every construction); the parameter is ``max_epochs``.
        self.max_epochs = int(max_epochs)

    def run_epoch(self) -> None:
        """Run one pass over the paired positive and negative batches.

        Iteration stops as soon as either batcher is exhausted. Any
        surplus batches from the longer stream are ignored.

        Raises:
            ValueError: If a positive batch and its paired negative batch
                have a different number of edges.
        """
        # zip pulls from pos first, then neg, and stops at the shorter
        # stream: the same order as explicit next() calls.
        for pos_batch, neg_batch in zip(self.pos_batcher, self.neg_batcher):
            if len(pos_batch.edges) != len(neg_batch.edges):
                raise ValueError(
                    'Positive and negative batch should have same length')
            # NOTE(review): no model update happens here yet; presumably
            # the training step on (pos_batch, neg_batch) is still to be
            # added — confirm.

    def run(self) -> None:
        """Run ``self.max_epochs`` consecutive epochs."""
        for _ in range(self.max_epochs):
            self.run_epoch()