|
@@ -118,7 +118,7 @@ def main(): |
|
|
# model = torch.nn.DataParallel(model, ['cuda:0', 'cuda:1'])
|
|
|
# model = torch.nn.DataParallel(model, ['cuda:0', 'cuda:1'])
|
|
|
_wrap(TrainLoop, 'build')
|
|
|
_wrap(TrainLoop, 'build')
|
|
|
_wrap(TrainLoop, 'run_epoch')
|
|
|
_wrap(TrainLoop, 'run_epoch')
|
|
|
loop = TrainLoop(model, batch_size=1000000)
|
|
|
|
|
|
|
|
|
loop = TrainLoop(model, batch_size=512, shuffle=True)
|
|
|
loop.run_epoch()
|
|
|
loop.run_epoch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|