|
|
@@ -5,6 +5,7 @@ from .batch import PredictionsBatch, \ |
|
|
|
gather_batch_indices
|
|
|
|
from typing import Callable
|
|
|
|
from types import FunctionType
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
|
|
class TrainLoop(object):
|
|
|
@@ -54,9 +55,15 @@ class TrainLoop(object): |
|
|
|
loss_sum = 0
|
|
|
|
for i, indices in enumerate(batch):
|
|
|
|
print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count))
|
|
|
|
t = time.time()
|
|
|
|
self.opt.zero_grad()
|
|
|
|
print('zero_grad() took:', time.time() - t)
|
|
|
|
t = time.time()
|
|
|
|
pred = self.model(None)
|
|
|
|
print('model() took:', time.time() - t)
|
|
|
|
t = time.time()
|
|
|
|
pred = flatten_predictions(pred)
|
|
|
|
print('flatten_predictions() took:', time.time() - t)
|
|
|
|
# batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
|
|
|
|
# seed = torch.rand(1).item()
|
|
|
|
# rng_state = torch.get_rng_state()
|
|
|
@@ -66,10 +73,18 @@ class TrainLoop(object): |
|
|
|
#for k in range(i):
|
|
|
|
#_ = next(it)
|
|
|
|
#(input, target) = next(it)
|
|
|
|
t = time.time()
|
|
|
|
(input, target) = gather_batch_indices(pred, indices)
|
|
|
|
print('gather_batch_indices() took:', time.time() - t)
|
|
|
|
t = time.time()
|
|
|
|
loss = self.loss(input, target)
|
|
|
|
print('loss() took:', time.time() - t)
|
|
|
|
t = time.time()
|
|
|
|
loss.backward()
|
|
|
|
print('backward() took:', time.time() - t)
|
|
|
|
t = time.time()
|
|
|
|
self.opt.step()
|
|
|
|
print('step() took:', time.time() - t)
|
|
|
|
loss_sum += loss.detach().cpu().item()
|
|
|
|
return loss_sum
|
|
|
|
|
|
|
|