IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

Add shuffle param to TrainLoop.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
b9e7e395cb
共有 1 個文件被更改,包括 9 次插入3 次删除
  1. +9
    -3
      src/icosagon/trainloop.py

+ 9
- 3
src/icosagon/trainloop.py 查看文件

@@ -8,10 +8,15 @@ from types import FunctionType
class TrainLoop(object):
def __init__(self, model: Model, lr: float = 0.001,
def __init__(
self,
model: Model,
lr: float = 0.001,
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
torch.nn.functional.binary_cross_entropy_with_logits,
batch_size: int = 100, generator: torch.Generator = None) -> None:
batch_size: int = 100,
shuffle: bool = False,
generator: torch.Generator = None) -> None:
if not isinstance(model, Model):
raise TypeError('model must be an instance of Model')
@@ -30,6 +35,7 @@ class TrainLoop(object):
self.lr = lr
self.loss = loss
self.batch_size = batch_size
self.shuffle = shuffle
self.generator = generator or torch.default_generator
self.opt = None
@@ -42,7 +48,7 @@ class TrainLoop(object):
def run_epoch(self):
batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
generator=self.generator)
shuffle = self.shuffle, generator=self.generator)
# pred = self.model(None)
# n = len(list(iter(batch)))
loss_sum = 0


Loading…
取消
儲存