diff --git a/src/icosagon/convlayer.py b/src/icosagon/convlayer.py index e98b55e..9b8e5ae 100644 --- a/src/icosagon/convlayer.py +++ b/src/icosagon/convlayer.py @@ -7,6 +7,7 @@ from typing import List, \ Callable from collections import defaultdict from dataclasses import dataclass +import time class Convolutions(torch.nn.Module): @@ -104,6 +105,7 @@ class DecagonLayer(torch.nn.Module): self.build_family(fam) def __call__(self, prev_layer_repr): + t = time.time() next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] n = len(self.data.node_types) @@ -120,4 +122,5 @@ class DecagonLayer(torch.nn.Module): next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) + print('DecagonLayer.forward() took', time.time() - t) return next_layer_repr diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index 8f6bff4..13f751d 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -16,6 +16,7 @@ from typing import Type, \ Tuple from .decode import DEDICOMDecoder from dataclasses import dataclass +import time @dataclass @@ -75,6 +76,7 @@ class DecodeLayer(torch.nn.Module): self.decoders.append(dec) def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec): + start_time = time.time() pred = [] for p in edge_list_attr_names: tvt = [] @@ -86,9 +88,11 @@ class DecodeLayer(torch.nn.Module): tvt.append(dec(inputs_row, inputs_column, k)) tvt = TrainValTest(*tvt) pred.append(tvt) + print('DecodeLayer._get_tvt() took:', time.time() - start_time) return pred def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: + t = time.time() res = [] for i, fam in enumerate(self.data.relation_families): fam_pred = [] @@ -103,4 +107,5 @@ class DecodeLayer(torch.nn.Module): fam_pred = RelationFamilyPredictions(fam_pred) res.append(fam_pred) res = Predictions(res) + print('DecodeLayer.forward() took', time.time() - t) return res diff --git a/src/icosagon/trainloop.py b/src/icosagon/trainloop.py index 098baba..40cb122 100644 --- a/src/icosagon/trainloop.py +++ b/src/icosagon/trainloop.py @@ -5,6 +5,7 @@ from .batch import PredictionsBatch, \ gather_batch_indices from typing import Callable from types import FunctionType +import time class TrainLoop(object): @@ -54,9 +55,15 @@ class TrainLoop(object): loss_sum = 0 for i, indices in enumerate(batch): print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count)) + t = time.time() self.opt.zero_grad() + print('zero_grad() took:', time.time() - t) + t = time.time() pred = self.model(None) + print('model() took:', time.time() - t) + t = time.time() pred = flatten_predictions(pred) + print('flatten_predictions() took:', time.time() - t) # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) # seed = torch.rand(1).item() # rng_state = torch.get_rng_state() @@ -66,10 +73,18 @@ class TrainLoop(object): #for k in range(i): #_ = next(it) #(input, target) = next(it) + t = time.time() (input, target) = gather_batch_indices(pred, indices) + print('gather_batch_indices() took:', time.time() - t) + t = time.time() loss = self.loss(input, target) + print('loss() took:', time.time() - t) + t = time.time() loss.backward() + print('backward() took:', time.time() - t) + t = time.time() self.opt.step() + print('step() took:', time.time() - t) loss_sum += loss.detach().cpu().item() return loss_sum diff --git a/tests/icosagon/test_trainloop.py b/tests/icosagon/test_trainloop.py index 5271a06..be6273c 100644 --- a/tests/icosagon/test_trainloop.py +++ b/tests/icosagon/test_trainloop.py @@ -6,6 +6,7 @@ from icosagon.trainloop import TrainLoop import torch import pytest import pdb +import time def test_train_loop_01(): @@ -69,3 +70,12 @@ def test_train_loop_03(): loop = TrainLoop(m) loop.run_epoch() + + +def test_timing_01(): + adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse() + rep = torch.eye(2000).requires_grad_(True) + t = time.time() + for _ in range(1300): + _ = torch.sparse.mm(adj_mat, rep) + print('Elapsed:', time.time() - t)