diff --git a/src/icosagon/convlayer.py b/src/icosagon/convlayer.py index 9b8e5ae..3c5b603 100644 --- a/src/icosagon/convlayer.py +++ b/src/icosagon/convlayer.py @@ -122,5 +122,5 @@ class DecagonLayer(torch.nn.Module): next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) - print('DecagonLayer.forward() took', time.time() - t) + # print('DecagonLayer.forward() took', time.time() - t) return next_layer_repr diff --git a/src/icosagon/databatch.py b/src/icosagon/databatch.py index 6f96baf..3602d6d 100644 --- a/src/icosagon/databatch.py +++ b/src/icosagon/databatch.py @@ -11,6 +11,11 @@ class BatchedData(PreparedData): super().__init__(*args, **kwargs) +class BatchedDataPointer(object): + def __init__(self, batched_data): + self.batched_data = batched_data + + def batched_data_skeleton(data: PreparedData) -> BatchedData: if not isinstance(data, PreparedData): raise TypeError('data must be an instance of PreparedData') diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index 13f751d..25d9c5f 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -17,6 +17,7 @@ from typing import Type, \ from .decode import DEDICOMDecoder from dataclasses import dataclass import time +from .databatch import BatchedDataPointer @dataclass @@ -43,6 +44,7 @@ class DecodeLayer(torch.nn.Module): data: PreparedData, keep_prob: float = 1., activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, + batched_data_pointer: BatchedDataPointer = None, **kwargs) -> None: super().__init__(**kwargs) @@ -59,11 +61,19 @@ class DecodeLayer(torch.nn.Module): if not isinstance(data, PreparedData): raise TypeError('data must be an instance of PreparedData') + if batched_data_pointer is not None and \ + not isinstance(batched_data_pointer, BatchedDataPointer): + raise TypeError('batched_data_pointer must be an instance of BatchedDataPointer') + + # if batched_data_pointer is not None and not batched_data_pointer.compatible_with(data): + # raise ValueError('batched_data_pointer must be compatible with data') + self.input_dim = input_dim[0] self.output_dim = 1 self.data = data self.keep_prob = keep_prob self.activation = activation + self.batched_data_pointer = batched_data_pointer self.decoders = None self.build() @@ -88,13 +98,16 @@ class DecodeLayer(torch.nn.Module): tvt.append(dec(inputs_row, inputs_column, k)) tvt = TrainValTest(*tvt) pred.append(tvt) - print('DecodeLayer._get_tvt() took:', time.time() - start_time) + # print('DecodeLayer._get_tvt() took:', time.time() - start_time) return pred def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: t = time.time() res = [] - for i, fam in enumerate(self.data.relation_families): + data = self.batched_data_pointer.batched_data \ + if self.batched_data_pointer is not None \ + else self.data + for i, fam in enumerate(data.relation_families): fam_pred = [] for k, r in enumerate(fam.relation_types): pred = [] @@ -107,5 +120,5 @@ class DecodeLayer(torch.nn.Module): fam_pred = RelationFamilyPredictions(fam_pred) res.append(fam_pred) res = Predictions(res) - print('DecodeLayer.forward() took', time.time() - t) + # print('DecodeLayer.forward() took', time.time() - t) return res diff --git a/tests/icosagon/test_databatch.py b/tests/icosagon/test_databatch.py index 2a35843..b36b5da 100644 --- a/tests/icosagon/test_databatch.py +++ b/tests/icosagon/test_databatch.py @@ -1,9 +1,14 @@ from icosagon.databatch import DataBatcher, \ - BatchedData + BatchedData, \ + BatchedDataPointer, \ + batched_data_skeleton from icosagon.data import Data from icosagon.trainprep import prepare_training, \ TrainValTest +from icosagon.declayer import DecodeLayer +from icosagon.input import OneHotInputLayer import torch +import time def _some_data(): @@ -16,6 +21,16 @@ def _some_data(): return data +def _some_data_big(): + data = Data() + data.add_node_type('Foo', 2000) + data.add_node_type('Bar', 2100) + fam = data.add_relation_family('Foo-Bar', 0, 1, True) + adj_mat = torch.rand(2000, 2100).round().to_sparse() + fam.add_relation_type('Foo-Bar', adj_mat) + return data + + def test_data_batcher_01(): data = _some_data() prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) @@ -79,3 +94,33 @@ def test_data_batcher_05(): assert all([ len(edges) <= 512 for edges in edges_list ]) assert not all([ len(edges) == 0 for edges in edges_list ]) print(sum(map(len, edges_list))) + + +def test_batch_decode_01(): + data = _some_data() + prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) + batcher = DataBatcher(prep_d, 512) + ptr = BatchedDataPointer(batched_data_skeleton(prep_d)) + in_repr = [ torch.rand(100, 32), + torch.rand(500, 32) ] + dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr) + t = time.time() + for batched_data in batcher: + ptr.batched_data = batched_data + _ = dec_layer(in_repr) + print('Elapsed:', time.time() - t) + + +def test_batch_decode_02(): + data = _some_data_big() + prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) + batcher = DataBatcher(prep_d, 512) + ptr = BatchedDataPointer(batched_data_skeleton(prep_d)) + in_repr = [ torch.rand(2000, 32), + torch.rand(2100, 32) ] + dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr) + t = time.time() + for batched_data in batcher: + ptr.batched_data = batched_data + _ = dec_layer(in_repr) + print('Elapsed:', time.time() - t)