IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add shuffle param to TrainLoop.

master
Stanislaw Adaszewski 4 years ago
parent
commit
b9e7e395cb
1 changed files with 9 additions and 3 deletions
  1. +9
    -3
      src/icosagon/trainloop.py

+ 9
- 3
src/icosagon/trainloop.py View File

@@ -8,10 +8,15 @@ from types import FunctionType
class TrainLoop(object): class TrainLoop(object):
def __init__(self, model: Model, lr: float = 0.001,
def __init__(
self,
model: Model,
lr: float = 0.001,
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
torch.nn.functional.binary_cross_entropy_with_logits, torch.nn.functional.binary_cross_entropy_with_logits,
batch_size: int = 100, generator: torch.Generator = None) -> None:
batch_size: int = 100,
shuffle: bool = False,
generator: torch.Generator = None) -> None:
if not isinstance(model, Model): if not isinstance(model, Model):
raise TypeError('model must be an instance of Model') raise TypeError('model must be an instance of Model')
@@ -30,6 +35,7 @@ class TrainLoop(object):
self.lr = lr self.lr = lr
self.loss = loss self.loss = loss
self.batch_size = batch_size self.batch_size = batch_size
self.shuffle = shuffle
self.generator = generator or torch.default_generator self.generator = generator or torch.default_generator
self.opt = None self.opt = None
@@ -42,7 +48,7 @@ class TrainLoop(object):
def run_epoch(self): def run_epoch(self):
batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size, batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
generator=self.generator)
shuffle = self.shuffle, generator=self.generator)
# pred = self.model(None) # pred = self.model(None)
# n = len(list(iter(batch))) # n = len(list(iter(batch)))
loss_sum = 0 loss_sum = 0


Loading…
Cancel
Save