From 45a18a46aa60e451430c79325d7d861bd6fa96f2 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 21 Jul 2020 14:26:16 +0200 Subject: [PATCH] Add shuffle to DataBatcher. --- src/icosagon/databatch.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/icosagon/databatch.py b/src/icosagon/databatch.py index b7c4a8f..6f96baf 100644 --- a/src/icosagon/databatch.py +++ b/src/icosagon/databatch.py @@ -2,6 +2,8 @@ from icosagon.trainprep import PreparedData, \ PreparedRelationFamily, \ PreparedRelationType, \ _empty_edge_list_tvt +import torch +import random class BatchedData(PreparedData): @@ -31,11 +33,13 @@ def batched_data_skeleton(data: PreparedData) -> BatchedData: class DataBatcher(object): - def __init__(self, data: PreparedData, batch_size: int) -> None: + def __init__(self, data: PreparedData, batch_size: int, + shuffle: bool = True) -> None: self._check_params(data, batch_size) self.data = data self.batch_size = batch_size + self.shuffle = shuffle # def batched_data_iter(self, fam_idx: int, rel_idx: int, # part_type: str) -> BatchedData: @@ -71,17 +75,34 @@ class DataBatcher(object): # yield batched_data def __iter__(self) -> BatchedData: + gen = self.shuffle_iter() \ + if self.shuffle \ + else self.iter_base() + + for batched_data in gen: + yield batched_data + + def iter_base(self) -> BatchedData: for i, fam in enumerate(self.data.relation_families): for k, rel in enumerate(fam.relation_types): for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']: for part_type in ['train', 'val', 'test']: edges = getattr(getattr(rel, edge_type), part_type) + if self.shuffle: + perm = torch.randperm(len(edges)) + edges = edges[perm] for m in range(0, len(edges), self.batch_size): batched_data = batched_data_skeleton(self.data) setattr(getattr(batched_data.relation_families[i].relation_types[k], edge_type), part_type, edges[m : m + self.batch_size]) yield batched_data + def shuffle_iter(self) -> BatchedData: + res = list(self.iter_base()) + random.shuffle(res) + for batched_data in res: + yield batched_data + @staticmethod def _check_params(data, batch_size): if not isinstance(data, PreparedData):