diff --git a/experiments/decagon_run/decagon_run.py b/experiments/decagon_run/decagon_run.py index 911caaa..74e1a03 100644 --- a/experiments/decagon_run/decagon_run.py +++ b/experiments/decagon_run/decagon_run.py @@ -118,7 +118,7 @@ def main(): # model = torch.nn.DataParallel(model, ['cuda:0', 'cuda:1']) _wrap(TrainLoop, 'build') _wrap(TrainLoop, 'run_epoch') - loop = TrainLoop(model, batch_size=1000000) + loop = TrainLoop(model, batch_size=512, shuffle=True) loop.run_epoch()