@@ -7,6 +7,7 @@ from typing import List, \ | |||
Callable | |||
from collections import defaultdict | |||
from dataclasses import dataclass | |||
import time | |||
class Convolutions(torch.nn.Module): | |||
@@ -104,6 +105,7 @@ class DecagonLayer(torch.nn.Module): | |||
self.build_family(fam) | |||
def __call__(self, prev_layer_repr): | |||
t = time.time() | |||
next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | |||
n = len(self.data.node_types) | |||
@@ -120,4 +122,5 @@ class DecagonLayer(torch.nn.Module): | |||
next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) | |||
next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) | |||
print('DecagonLayer.forward() took', time.time() - t) | |||
return next_layer_repr |
@@ -16,6 +16,7 @@ from typing import Type, \ | |||
Tuple | |||
from .decode import DEDICOMDecoder | |||
from dataclasses import dataclass | |||
import time | |||
@dataclass | |||
@@ -75,6 +76,7 @@ class DecodeLayer(torch.nn.Module): | |||
self.decoders.append(dec) | |||
def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec): | |||
start_time = time.time() | |||
pred = [] | |||
for p in edge_list_attr_names: | |||
tvt = [] | |||
@@ -86,9 +88,11 @@ class DecodeLayer(torch.nn.Module): | |||
tvt.append(dec(inputs_row, inputs_column, k)) | |||
tvt = TrainValTest(*tvt) | |||
pred.append(tvt) | |||
print('DecodeLayer._get_tvt() took:', time.time() - start_time) | |||
return pred | |||
def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: | |||
t = time.time() | |||
res = [] | |||
for i, fam in enumerate(self.data.relation_families): | |||
fam_pred = [] | |||
@@ -103,4 +107,5 @@ class DecodeLayer(torch.nn.Module): | |||
fam_pred = RelationFamilyPredictions(fam_pred) | |||
res.append(fam_pred) | |||
res = Predictions(res) | |||
print('DecodeLayer.forward() took', time.time() - t) | |||
return res |
@@ -5,6 +5,7 @@ from .batch import PredictionsBatch, \ | |||
gather_batch_indices | |||
from typing import Callable | |||
from types import FunctionType | |||
import time | |||
class TrainLoop(object): | |||
@@ -54,9 +55,15 @@ class TrainLoop(object): | |||
loss_sum = 0 | |||
for i, indices in enumerate(batch): | |||
print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count)) | |||
t = time.time() | |||
self.opt.zero_grad() | |||
print('zero_grad() took:', time.time() - t) | |||
t = time.time() | |||
pred = self.model(None) | |||
print('model() took:', time.time() - t) | |||
t = time.time() | |||
pred = flatten_predictions(pred) | |||
print('flatten_predictions() took:', time.time() - t) | |||
# batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||
# seed = torch.rand(1).item() | |||
# rng_state = torch.get_rng_state() | |||
@@ -66,10 +73,18 @@ class TrainLoop(object): | |||
#for k in range(i): | |||
#_ = next(it) | |||
#(input, target) = next(it) | |||
t = time.time() | |||
(input, target) = gather_batch_indices(pred, indices) | |||
print('gather_batch_indices() took:', time.time() - t) | |||
t = time.time() | |||
loss = self.loss(input, target) | |||
print('loss() took:', time.time() - t) | |||
t = time.time() | |||
loss.backward() | |||
print('backward() took:', time.time() - t) | |||
t = time.time() | |||
self.opt.step() | |||
print('step() took:', time.time() - t) | |||
loss_sum += loss.detach().cpu().item() | |||
return loss_sum | |||
@@ -6,6 +6,7 @@ from icosagon.trainloop import TrainLoop | |||
import torch | |||
import pytest | |||
import pdb | |||
import time | |||
def test_train_loop_01(): | |||
@@ -69,3 +70,12 @@ def test_train_loop_03(): | |||
loop = TrainLoop(m) | |||
loop.run_epoch() | |||
def test_timing_01(): | |||
adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse() | |||
rep = torch.eye(2000).requires_grad_(True) | |||
t = time.time() | |||
for _ in range(1300): | |||
_ = torch.sparse.mm(adj_mat, rep) | |||
print('Elapsed:', time.time() - t) |