|
@@ -99,13 +99,13 @@ class Model(object): |
|
|
|
|
|
|
|
|
def run_epoch(self):
|
|
|
def run_epoch(self):
|
|
|
pred = self.seq(None)
|
|
|
pred = self.seq(None)
|
|
|
batch = PredictionsBatch(pred, self.batch_size)
|
|
|
|
|
|
|
|
|
batch = PredictionsBatch(pred, batch_size=self.batch_size)
|
|
|
n = len(list(iter(batch)))
|
|
|
n = len(list(iter(batch)))
|
|
|
loss_sum = 0
|
|
|
loss_sum = 0
|
|
|
for i in range(n - 1):
|
|
|
|
|
|
|
|
|
for i in range(n):
|
|
|
self.opt.zero_grad()
|
|
|
self.opt.zero_grad()
|
|
|
pred = self.seq(None)
|
|
|
pred = self.seq(None)
|
|
|
batch = PredictionsBatch(pred, self.batch_size, shuffle=True)
|
|
|
|
|
|
|
|
|
batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
|
|
|
seed = torch.rand(1).item()
|
|
|
seed = torch.rand(1).item()
|
|
|
rng_state = torch.get_rng_state()
|
|
|
rng_state = torch.get_rng_state()
|
|
|
torch.manual_seed(seed)
|
|
|
torch.manual_seed(seed)
|
|
@@ -116,7 +116,7 @@ class Model(object): |
|
|
(input, target) = next(it)
|
|
|
(input, target) = next(it)
|
|
|
loss = self.loss(input, target)
|
|
|
loss = self.loss(input, target)
|
|
|
loss.backward()
|
|
|
loss.backward()
|
|
|
self.opt.optimize()
|
|
|
|
|
|
|
|
|
self.opt.step()
|
|
|
loss_sum += loss.detach().cpu().item()
|
|
|
loss_sum += loss.detach().cpu().item()
|
|
|
return loss_sum
|
|
|
return loss_sum
|
|
|
|
|
|
|
|
|