|
|
@@ -8,10 +8,15 @@ from types import FunctionType |
|
|
|
|
|
|
|
|
|
|
|
class TrainLoop(object):
|
|
|
|
def __init__(self, model: Model, lr: float = 0.001,
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
model: Model,
|
|
|
|
lr: float = 0.001,
|
|
|
|
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
|
|
|
|
torch.nn.functional.binary_cross_entropy_with_logits,
|
|
|
|
batch_size: int = 100, generator: torch.Generator = None) -> None:
|
|
|
|
batch_size: int = 100,
|
|
|
|
shuffle: bool = False,
|
|
|
|
generator: torch.Generator = None) -> None:
|
|
|
|
|
|
|
|
if not isinstance(model, Model):
|
|
|
|
raise TypeError('model must be an instance of Model')
|
|
|
@@ -30,6 +35,7 @@ class TrainLoop(object): |
|
|
|
self.lr = lr
|
|
|
|
self.loss = loss
|
|
|
|
self.batch_size = batch_size
|
|
|
|
self.shuffle = shuffle
|
|
|
|
self.generator = generator or torch.default_generator
|
|
|
|
|
|
|
|
self.opt = None
|
|
|
@@ -42,7 +48,7 @@ class TrainLoop(object): |
|
|
|
|
|
|
|
def run_epoch(self):
|
|
|
|
batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
|
|
|
|
generator=self.generator)
|
|
|
|
shuffle = self.shuffle, generator=self.generator)
|
|
|
|
# pred = self.model(None)
|
|
|
|
# n = len(list(iter(batch)))
|
|
|
|
loss_sum = 0
|
|
|
|