@@ -7,6 +7,7 @@ from typing import List, \ | |||||
Callable | Callable | ||||
from collections import defaultdict | from collections import defaultdict | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
import time | |||||
class Convolutions(torch.nn.Module): | class Convolutions(torch.nn.Module): | ||||
@@ -104,6 +105,7 @@ class DecagonLayer(torch.nn.Module): | |||||
self.build_family(fam) | self.build_family(fam) | ||||
def __call__(self, prev_layer_repr): | def __call__(self, prev_layer_repr): | ||||
t = time.time() | |||||
next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | ||||
n = len(self.data.node_types) | n = len(self.data.node_types) | ||||
@@ -120,4 +122,5 @@ class DecagonLayer(torch.nn.Module): | |||||
next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) | next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) | ||||
next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) | next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) | ||||
print('DecagonLayer.forward() took', time.time() - t) | |||||
return next_layer_repr | return next_layer_repr |
@@ -16,6 +16,7 @@ from typing import Type, \ | |||||
Tuple | Tuple | ||||
from .decode import DEDICOMDecoder | from .decode import DEDICOMDecoder | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
import time | |||||
@dataclass | @dataclass | ||||
@@ -75,6 +76,7 @@ class DecodeLayer(torch.nn.Module): | |||||
self.decoders.append(dec) | self.decoders.append(dec) | ||||
def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec): | def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec): | ||||
start_time = time.time() | |||||
pred = [] | pred = [] | ||||
for p in edge_list_attr_names: | for p in edge_list_attr_names: | ||||
tvt = [] | tvt = [] | ||||
@@ -86,9 +88,11 @@ class DecodeLayer(torch.nn.Module): | |||||
tvt.append(dec(inputs_row, inputs_column, k)) | tvt.append(dec(inputs_row, inputs_column, k)) | ||||
tvt = TrainValTest(*tvt) | tvt = TrainValTest(*tvt) | ||||
pred.append(tvt) | pred.append(tvt) | ||||
print('DecodeLayer._get_tvt() took:', time.time() - start_time) | |||||
return pred | return pred | ||||
def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: | def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: | ||||
t = time.time() | |||||
res = [] | res = [] | ||||
for i, fam in enumerate(self.data.relation_families): | for i, fam in enumerate(self.data.relation_families): | ||||
fam_pred = [] | fam_pred = [] | ||||
@@ -103,4 +107,5 @@ class DecodeLayer(torch.nn.Module): | |||||
fam_pred = RelationFamilyPredictions(fam_pred) | fam_pred = RelationFamilyPredictions(fam_pred) | ||||
res.append(fam_pred) | res.append(fam_pred) | ||||
res = Predictions(res) | res = Predictions(res) | ||||
print('DecodeLayer.forward() took', time.time() - t) | |||||
return res | return res |
@@ -5,6 +5,7 @@ from .batch import PredictionsBatch, \ | |||||
gather_batch_indices | gather_batch_indices | ||||
from typing import Callable | from typing import Callable | ||||
from types import FunctionType | from types import FunctionType | ||||
import time | |||||
class TrainLoop(object): | class TrainLoop(object): | ||||
@@ -54,9 +55,15 @@ class TrainLoop(object): | |||||
loss_sum = 0 | loss_sum = 0 | ||||
for i, indices in enumerate(batch): | for i, indices in enumerate(batch): | ||||
print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count)) | print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count)) | ||||
t = time.time() | |||||
self.opt.zero_grad() | self.opt.zero_grad() | ||||
print('zero_grad() took:', time.time() - t) | |||||
t = time.time() | |||||
pred = self.model(None) | pred = self.model(None) | ||||
print('model() took:', time.time() - t) | |||||
t = time.time() | |||||
pred = flatten_predictions(pred) | pred = flatten_predictions(pred) | ||||
print('flatten_predictions() took:', time.time() - t) | |||||
# batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | ||||
# seed = torch.rand(1).item() | # seed = torch.rand(1).item() | ||||
# rng_state = torch.get_rng_state() | # rng_state = torch.get_rng_state() | ||||
@@ -66,10 +73,18 @@ class TrainLoop(object): | |||||
#for k in range(i): | #for k in range(i): | ||||
#_ = next(it) | #_ = next(it) | ||||
#(input, target) = next(it) | #(input, target) = next(it) | ||||
t = time.time() | |||||
(input, target) = gather_batch_indices(pred, indices) | (input, target) = gather_batch_indices(pred, indices) | ||||
print('gather_batch_indices() took:', time.time() - t) | |||||
t = time.time() | |||||
loss = self.loss(input, target) | loss = self.loss(input, target) | ||||
print('loss() took:', time.time() - t) | |||||
t = time.time() | |||||
loss.backward() | loss.backward() | ||||
print('backward() took:', time.time() - t) | |||||
t = time.time() | |||||
self.opt.step() | self.opt.step() | ||||
print('step() took:', time.time() - t) | |||||
loss_sum += loss.detach().cpu().item() | loss_sum += loss.detach().cpu().item() | ||||
return loss_sum | return loss_sum | ||||
@@ -6,6 +6,7 @@ from icosagon.trainloop import TrainLoop | |||||
import torch | import torch | ||||
import pytest | import pytest | ||||
import pdb | import pdb | ||||
import time | |||||
def test_train_loop_01(): | def test_train_loop_01(): | ||||
@@ -69,3 +70,12 @@ def test_train_loop_03(): | |||||
loop = TrainLoop(m) | loop = TrainLoop(m) | ||||
loop.run_epoch() | loop.run_epoch() | ||||
def test_timing_01(): | |||||
adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse() | |||||
rep = torch.eye(2000).requires_grad_(True) | |||||
t = time.time() | |||||
for _ in range(1300): | |||||
_ = torch.sparse.mm(adj_mat, rep) | |||||
print('Elapsed:', time.time() - t) |