| @@ -7,6 +7,7 @@ from typing import List, \ | |||||
| Callable | Callable | ||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| import time | |||||
| class Convolutions(torch.nn.Module): | class Convolutions(torch.nn.Module): | ||||
| @@ -104,6 +105,7 @@ class DecagonLayer(torch.nn.Module): | |||||
| self.build_family(fam) | self.build_family(fam) | ||||
| def __call__(self, prev_layer_repr): | def __call__(self, prev_layer_repr): | ||||
| t = time.time() | |||||
| next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | ||||
| n = len(self.data.node_types) | n = len(self.data.node_types) | ||||
| @@ -120,4 +122,5 @@ class DecagonLayer(torch.nn.Module): | |||||
| next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) | next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) | ||||
| next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) | next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) | ||||
| print('DecagonLayer.forward() took', time.time() - t) | |||||
| return next_layer_repr | return next_layer_repr | ||||
| @@ -16,6 +16,7 @@ from typing import Type, \ | |||||
| Tuple | Tuple | ||||
| from .decode import DEDICOMDecoder | from .decode import DEDICOMDecoder | ||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| import time | |||||
| @dataclass | @dataclass | ||||
| @@ -75,6 +76,7 @@ class DecodeLayer(torch.nn.Module): | |||||
| self.decoders.append(dec) | self.decoders.append(dec) | ||||
| def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec): | def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec): | ||||
| start_time = time.time() | |||||
| pred = [] | pred = [] | ||||
| for p in edge_list_attr_names: | for p in edge_list_attr_names: | ||||
| tvt = [] | tvt = [] | ||||
| @@ -86,9 +88,11 @@ class DecodeLayer(torch.nn.Module): | |||||
| tvt.append(dec(inputs_row, inputs_column, k)) | tvt.append(dec(inputs_row, inputs_column, k)) | ||||
| tvt = TrainValTest(*tvt) | tvt = TrainValTest(*tvt) | ||||
| pred.append(tvt) | pred.append(tvt) | ||||
| print('DecodeLayer._get_tvt() took:', time.time() - start_time) | |||||
| return pred | return pred | ||||
| def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: | def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: | ||||
| t = time.time() | |||||
| res = [] | res = [] | ||||
| for i, fam in enumerate(self.data.relation_families): | for i, fam in enumerate(self.data.relation_families): | ||||
| fam_pred = [] | fam_pred = [] | ||||
| @@ -103,4 +107,5 @@ class DecodeLayer(torch.nn.Module): | |||||
| fam_pred = RelationFamilyPredictions(fam_pred) | fam_pred = RelationFamilyPredictions(fam_pred) | ||||
| res.append(fam_pred) | res.append(fam_pred) | ||||
| res = Predictions(res) | res = Predictions(res) | ||||
| print('DecodeLayer.forward() took', time.time() - t) | |||||
| return res | return res | ||||
| @@ -5,6 +5,7 @@ from .batch import PredictionsBatch, \ | |||||
| gather_batch_indices | gather_batch_indices | ||||
| from typing import Callable | from typing import Callable | ||||
| from types import FunctionType | from types import FunctionType | ||||
| import time | |||||
| class TrainLoop(object): | class TrainLoop(object): | ||||
| @@ -54,9 +55,15 @@ class TrainLoop(object): | |||||
| loss_sum = 0 | loss_sum = 0 | ||||
| for i, indices in enumerate(batch): | for i, indices in enumerate(batch): | ||||
| print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count)) | print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count)) | ||||
| t = time.time() | |||||
| self.opt.zero_grad() | self.opt.zero_grad() | ||||
| print('zero_grad() took:', time.time() - t) | |||||
| t = time.time() | |||||
| pred = self.model(None) | pred = self.model(None) | ||||
| print('model() took:', time.time() - t) | |||||
| t = time.time() | |||||
| pred = flatten_predictions(pred) | pred = flatten_predictions(pred) | ||||
| print('flatten_predictions() took:', time.time() - t) | |||||
| # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | ||||
| # seed = torch.rand(1).item() | # seed = torch.rand(1).item() | ||||
| # rng_state = torch.get_rng_state() | # rng_state = torch.get_rng_state() | ||||
| @@ -66,10 +73,18 @@ class TrainLoop(object): | |||||
| #for k in range(i): | #for k in range(i): | ||||
| #_ = next(it) | #_ = next(it) | ||||
| #(input, target) = next(it) | #(input, target) = next(it) | ||||
| t = time.time() | |||||
| (input, target) = gather_batch_indices(pred, indices) | (input, target) = gather_batch_indices(pred, indices) | ||||
| print('gather_batch_indices() took:', time.time() - t) | |||||
| t = time.time() | |||||
| loss = self.loss(input, target) | loss = self.loss(input, target) | ||||
| print('loss() took:', time.time() - t) | |||||
| t = time.time() | |||||
| loss.backward() | loss.backward() | ||||
| print('backward() took:', time.time() - t) | |||||
| t = time.time() | |||||
| self.opt.step() | self.opt.step() | ||||
| print('step() took:', time.time() - t) | |||||
| loss_sum += loss.detach().cpu().item() | loss_sum += loss.detach().cpu().item() | ||||
| return loss_sum | return loss_sum | ||||
| @@ -6,6 +6,7 @@ from icosagon.trainloop import TrainLoop | |||||
| import torch | import torch | ||||
| import pytest | import pytest | ||||
| import pdb | import pdb | ||||
| import time | |||||
| def test_train_loop_01(): | def test_train_loop_01(): | ||||
| @@ -69,3 +70,12 @@ def test_train_loop_03(): | |||||
| loop = TrainLoop(m) | loop = TrainLoop(m) | ||||
| loop.run_epoch() | loop.run_epoch() | ||||
| def test_timing_01(): | |||||
| adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse() | |||||
| rep = torch.eye(2000).requires_grad_(True) | |||||
| t = time.time() | |||||
| for _ in range(1300): | |||||
| _ = torch.sparse.mm(adj_mat, rep) | |||||
| print('Elapsed:', time.time() - t) | |||||