|
- from .model import Model
- from .batch import Batcher
-
-
class TrainLoop(object):
    """Drive training epochs over paired positive/negative batches.

    Each epoch walks the positive and negative batchers in lockstep and
    checks that every positive batch has as many edges as the negative
    batch it is paired with.
    """

    def __init__(self, model: Model,
                 pos_batcher: Batcher,
                 neg_batcher: Batcher,
                 max_epochs: int = 50) -> None:
        """Validate and store the training components.

        Args:
            model: The model being trained; must be a ``Model`` instance.
            pos_batcher: Source of positive batches; must be a ``Batcher``
                instance and is iterated once per epoch.
            neg_batcher: Source of negative batches; must be a ``Batcher``
                instance and is iterated once per epoch.
            max_epochs: Number of epochs performed by ``run``; coerced
                with ``int()``.

        Raises:
            TypeError: If ``model``, ``pos_batcher`` or ``neg_batcher`` is
                not of the expected type.
        """
        if not isinstance(model, Model):
            raise TypeError('model must be an instance of Model')

        if not isinstance(pos_batcher, Batcher):
            raise TypeError('pos_batcher must be an instance of Batcher')

        if not isinstance(neg_batcher, Batcher):
            raise TypeError('neg_batcher must be an instance of Batcher')

        self.model = model
        self.pos_batcher = pos_batcher
        self.neg_batcher = neg_batcher
        # Previously this read the undefined name ``num_epochs``, so every
        # construction failed with NameError; use the actual parameter.
        self.max_epochs = int(max_epochs)

    def run_epoch(self) -> None:
        """Run a single pass over the positive and negative batchers.

        Batches are consumed pairwise. The epoch ends as soon as either
        batcher is exhausted; any surplus batches from the other batcher
        are not processed.

        Raises:
            ValueError: If a positive batch and its paired negative batch
                have a different number of edges.
        """
        # zip() mirrors the previous next()/StopIteration loop exactly: it
        # creates both iterators up front, pulls from the positive side
        # first, and stops at the first exhausted iterator.
        for pos_batch, neg_batch in zip(self.pos_batcher, self.neg_batcher):
            if len(pos_batch.edges) != len(neg_batch.edges):
                raise ValueError('Positive and negative batch should have same length')
            # NOTE(review): no model update is performed on the batch pair
            # yet -- only the length check runs; presumably unfinished.

    def run(self) -> None:
        """Run ``self.max_epochs`` consecutive epochs via ``run_epoch``."""
        for _ in range(self.max_epochs):
            self.run_epoch()
|