@@ -7,6 +7,7 @@ from typing import List, \
     Callable
 from collections import defaultdict
 from dataclasses import dataclass
+import time
 class Convolutions(torch.nn.Module):
@@ -104,6 +105,7 @@ class DecagonLayer(torch.nn.Module):
             self.build_family(fam)
     def __call__(self, prev_layer_repr):
+        t = time.time()
         next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
         n = len(self.data.node_types)
@@ -120,4 +122,5 @@ class DecagonLayer(torch.nn.Module):
             next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row])
             next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row])
+        print('DecagonLayer.forward() took', time.time() - t)
         return next_layer_repr
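Every timed section in this patch repeats the same t = time.time() / print('... took', time.time() - t) pair. As a side note, a minimal sketch of what that pattern amounts to as a reusable helper; the timed name and the usage shown are illustrative, not part of the patch:

from contextlib import contextmanager
import time

@contextmanager
def timed(name):
    # Print the wall-clock duration of the enclosed block.
    t = time.time()
    yield
    print(name, 'took', time.time() - t)

# Hypothetical usage, mirroring the instrumentation above:
# with timed('DecagonLayer.forward()'):
#     next_layer_repr = layer(prev_layer_repr)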
@@ -16,6 +16,7 @@ from typing import Type, \
     Tuple
 from .decode import DEDICOMDecoder
 from dataclasses import dataclass
+import time
 @dataclass
@@ -75,6 +76,7 @@ class DecodeLayer(torch.nn.Module):
             self.decoders.append(dec)
     def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec):
+        start_time = time.time()
         pred = []
         for p in edge_list_attr_names:
             tvt = []
@@ -86,9 +88,11 @@ class DecodeLayer(torch.nn.Module):
                 tvt.append(dec(inputs_row, inputs_column, k))
             tvt = TrainValTest(*tvt)
             pred.append(tvt)
+        print('DecodeLayer._get_tvt() took:', time.time() - start_time)
         return pred
     def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+        t = time.time()
         res = []
         for i, fam in enumerate(self.data.relation_families):
             fam_pred = []
@@ -103,4 +107,5 @@ class DecodeLayer(torch.nn.Module):
             fam_pred = RelationFamilyPredictions(fam_pred)
             res.append(fam_pred)
         res = Predictions(res)
+        print('DecodeLayer.forward() took', time.time() - t)
         return res
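Since _get_tvt() is called from forward(), the forward() figure printed here already includes the _get_tvt() times printed just above it. For sections this short, time.perf_counter() is a monotonic, higher-resolution alternative to time.time(); a sketch of the same pattern with it, names illustrative only:

import time

start_time = time.perf_counter()
# ... section being measured, e.g. the body of _get_tvt() ...
print('DecodeLayer._get_tvt() took:', time.perf_counter() - start_time)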
@@ -5,6 +5,7 @@ from .batch import PredictionsBatch, \
     gather_batch_indices
 from typing import Callable
 from types import FunctionType
+import time
 class TrainLoop(object):
@@ -54,9 +55,15 @@ class TrainLoop(object):
         loss_sum = 0
         for i, indices in enumerate(batch):
             print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count))
+            t = time.time()
             self.opt.zero_grad()
+            print('zero_grad() took:', time.time() - t)
+            t = time.time()
             pred = self.model(None)
+            print('model() took:', time.time() - t)
+            t = time.time()
             pred = flatten_predictions(pred)
+            print('flatten_predictions() took:', time.time() - t)
             # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
             # seed = torch.rand(1).item()
             # rng_state = torch.get_rng_state()
@@ -66,10 +73,18 @@ class TrainLoop(object):
             #for k in range(i):
             #_ = next(it)
             #(input, target) = next(it)
+            t = time.time()
             (input, target) = gather_batch_indices(pred, indices)
+            print('gather_batch_indices() took:', time.time() - t)
+            t = time.time()
             loss = self.loss(input, target)
+            print('loss() took:', time.time() - t)
+            t = time.time()
             loss.backward()
+            print('backward() took:', time.time() - t)
+            t = time.time()
             self.opt.step()
+            print('step() took:', time.time() - t)
             loss_sum += loss.detach().cpu().item()
         return loss_sum
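One caveat worth flagging for these per-step timings: if the model and optimizer run on a GPU, CUDA kernels are launched asynchronously, so a plain time.time() around model(), backward() or step() can attribute work to whichever later call happens to block. A sketch of a synchronized variant, assuming CUDA may or may not be available:

import time
import torch

def timestamp():
    # Wait for pending GPU work before reading the clock, so each
    # printed duration reflects only its own step.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

t = timestamp()
# loss.backward()   # the step being measured
print('backward() took:', timestamp() - t)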
@@ -6,6 +6,7 @@ from icosagon.trainloop import TrainLoop
 import torch
 import pytest
 import pdb
+import time
 def test_train_loop_01():
@@ -69,3 +70,12 @@ def test_train_loop_03():
     loop = TrainLoop(m)
     loop.run_epoch()
+def test_timing_01():
+    adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
+    rep = torch.eye(2000).requires_grad_(True)
+    t = time.time()
+    for _ in range(1300):
+        _ = torch.sparse.mm(adj_mat, rep)
+    print('Elapsed:', time.time() - t)
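In the benchmark above, rep is marked as requiring grad, but the loop only times the 1300 forward torch.sparse.mm() calls. A variation that also exercises the backward pass through the sparse product, my own sketch rather than part of the patch, could look like:

import time
import torch

adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
rep = torch.eye(2000).requires_grad_(True)

t = time.time()
out = torch.sparse.mm(adj_mat, rep)
out.sum().backward()   # gradient w.r.t. the dense operand
print('Forward + backward took:', time.time() - t)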