From 346cc747a6007c6e9e1ef740439440d23e6be821 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 16 Jun 2020 18:12:44 +0200 Subject: [PATCH] Add shuffle to PredictionsBatch. --- src/icosagon/batch.py | 10 +++++++++- src/icosagon/model.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/icosagon/batch.py b/src/icosagon/batch.py index 275cac3..9a28712 100644 --- a/src/icosagon/batch.py +++ b/src/icosagon/batch.py @@ -4,7 +4,7 @@ import torch class PredictionsBatch(object): def __init__(self, pred: Predictions, part_type: str = 'train', - batch_size: int = 100) -> None: + batch_size: int = 100, shuffle: bool = False) -> None: if not isinstance(pred, Predictions): raise TypeError('pred must be an instance of Predictions') @@ -14,9 +14,12 @@ class PredictionsBatch(object): batch_size = int(batch_size) + shuffle = bool(shuffle) + self.predictions = pred self.part_type = part_type self.batch_size = batch_size + self.shuffle = shuffle def __iter__(self): edge_types = [('edges_pos', 1), ('edges_neg', 0), @@ -35,5 +38,10 @@ class PredictionsBatch(object): input = torch.cat(input) target = torch.cat(target) + if self.shuffle: + perm = torch.randperm(len(input)) + input = input[perm] + target = target[perm] + for i in range(0, len(input), self.batch_size): yield (input[i:i+self.batch_size], target[i:i+self.batch_size]) diff --git a/src/icosagon/model.py b/src/icosagon/model.py index ba28bb7..4c1cf0d 100644 --- a/src/icosagon/model.py +++ b/src/icosagon/model.py @@ -100,7 +100,7 @@ class Model(object): for i in range(n - 1): self.opt.zero_grad() pred = self.seq(None) - batch = PredictionsBatch(pred, self.batch_size) + batch = PredictionsBatch(pred, self.batch_size, shuffle=True) seed = torch.rand(1).item() rng_state = torch.get_rng_state() torch.manual_seed(seed)