| @@ -122,5 +122,5 @@ class DecagonLayer(torch.nn.Module): | |||
| next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) | |||
| next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) | |||
| print('DecagonLayer.forward() took', time.time() - t) | |||
| # print('DecagonLayer.forward() took', time.time() - t) | |||
| return next_layer_repr | |||
| @@ -11,6 +11,11 @@ class BatchedData(PreparedData): | |||
| super().__init__(*args, **kwargs) | |||
class BatchedDataPointer:
    """Mutable holder for the batch a consumer should currently read.

    The pointer is created once and handed to a consumer; the caller then
    rebinds ``batched_data`` to switch batches without rebuilding the
    consumer.

    Attributes:
        batched_data: The object currently pointed at. It is stored as
            given, with no validation or copying.
            NOTE(review): presumably a ``BatchedData`` instance -- confirm
            against callers.
    """

    def __init__(self, batched_data):
        self.batched_data = batched_data

    def __repr__(self):
        # Show the wrapped object so logs reveal which batch is active.
        return f'{type(self).__name__}(batched_data={self.batched_data!r})'
| def batched_data_skeleton(data: PreparedData) -> BatchedData: | |||
| if not isinstance(data, PreparedData): | |||
| raise TypeError('data must be an instance of PreparedData') | |||
| @@ -17,6 +17,7 @@ from typing import Type, \ | |||
| from .decode import DEDICOMDecoder | |||
| from dataclasses import dataclass | |||
| import time | |||
| from .databatch import BatchedDataPointer | |||
| @dataclass | |||
| @@ -43,6 +44,7 @@ class DecodeLayer(torch.nn.Module): | |||
| data: PreparedData, | |||
| keep_prob: float = 1., | |||
| activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | |||
| batched_data_pointer: BatchedDataPointer = None, | |||
| **kwargs) -> None: | |||
| super().__init__(**kwargs) | |||
| @@ -59,11 +61,19 @@ class DecodeLayer(torch.nn.Module): | |||
| if not isinstance(data, PreparedData): | |||
| raise TypeError('data must be an instance of PreparedData') | |||
| if batched_data_pointer is not None and \ | |||
| not isinstance(batched_data_pointer, BatchedDataPointer): | |||
| raise TypeError('batched_data_pointer must be an instance of BatchedDataPointer') | |||
| # if batched_data_pointer is not None and not batched_data_pointer.compatible_with(data): | |||
| # raise ValueError('batched_data_pointer must be compatible with data') | |||
| self.input_dim = input_dim[0] | |||
| self.output_dim = 1 | |||
| self.data = data | |||
| self.keep_prob = keep_prob | |||
| self.activation = activation | |||
| self.batched_data_pointer = batched_data_pointer | |||
| self.decoders = None | |||
| self.build() | |||
| @@ -88,13 +98,16 @@ class DecodeLayer(torch.nn.Module): | |||
| tvt.append(dec(inputs_row, inputs_column, k)) | |||
| tvt = TrainValTest(*tvt) | |||
| pred.append(tvt) | |||
| print('DecodeLayer._get_tvt() took:', time.time() - start_time) | |||
| # print('DecodeLayer._get_tvt() took:', time.time() - start_time) | |||
| return pred | |||
| def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: | |||
| t = time.time() | |||
| res = [] | |||
| for i, fam in enumerate(self.data.relation_families): | |||
| data = self.batched_data_pointer.batched_data \ | |||
| if self.batched_data_pointer is not None \ | |||
| else self.data | |||
| for i, fam in enumerate(data.relation_families): | |||
| fam_pred = [] | |||
| for k, r in enumerate(fam.relation_types): | |||
| pred = [] | |||
| @@ -107,5 +120,5 @@ class DecodeLayer(torch.nn.Module): | |||
| fam_pred = RelationFamilyPredictions(fam_pred) | |||
| res.append(fam_pred) | |||
| res = Predictions(res) | |||
| print('DecodeLayer.forward() took', time.time() - t) | |||
| # print('DecodeLayer.forward() took', time.time() - t) | |||
| return res | |||
| @@ -1,9 +1,14 @@ | |||
| from icosagon.databatch import DataBatcher, \ | |||
| BatchedData | |||
| BatchedData, \ | |||
| BatchedDataPointer, \ | |||
| batched_data_skeleton | |||
| from icosagon.data import Data | |||
| from icosagon.trainprep import prepare_training, \ | |||
| TrainValTest | |||
| from icosagon.declayer import DecodeLayer | |||
| from icosagon.input import OneHotInputLayer | |||
| import torch | |||
| import time | |||
| def _some_data(): | |||
| @@ -16,6 +21,16 @@ def _some_data(): | |||
| return data | |||
def _some_data_big():
    """Build a larger ``Data`` fixture with one random Foo-Bar relation.

    Registers node types 'Foo' (2000 nodes) and 'Bar' (2100 nodes), adds a
    'Foo-Bar' relation family between them, and attaches a single relation
    type whose adjacency matrix is a sparse 2000 x 2100 tensor of random
    0/1 entries (uniform samples rounded to the nearest integer).
    """
    n_foo, n_bar = 2000, 2100

    big = Data()
    big.add_node_type('Foo', n_foo)
    big.add_node_type('Bar', n_bar)

    family = big.add_relation_family('Foo-Bar', 0, 1, True)
    # Rounding uniform [0, 1) samples yields a random binary matrix.
    dense_adj = torch.rand(n_foo, n_bar).round()
    family.add_relation_type('Foo-Bar', dense_adj.to_sparse())
    return big
| def test_data_batcher_01(): | |||
| data = _some_data() | |||
| prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) | |||
| @@ -79,3 +94,33 @@ def test_data_batcher_05(): | |||
| assert all([ len(edges) <= 512 for edges in edges_list ]) | |||
| assert not all([ len(edges) == 0 for edges in edges_list ]) | |||
| print(sum(map(len, edges_list))) | |||
def test_batch_decode_01():
    """Smoke-test batched decoding on the small fixture.

    Runs a ``DecodeLayer`` once per batch yielded by a ``DataBatcher``
    (constructed with 512 as its second argument), rebinding the shared
    ``BatchedDataPointer`` before each call, and prints the elapsed
    wall-clock time. The test makes no assertions; it only checks that the
    loop completes without raising.
    """
    prepared = prepare_training(_some_data(), TrainValTest(.8, .1, .1))
    batcher = DataBatcher(prepared, 512)
    pointer = BatchedDataPointer(batched_data_skeleton(prepared))
    node_repr = [torch.rand(100, 32), torch.rand(500, 32)]
    layer = DecodeLayer([32, 32], prepared, batched_data_pointer=pointer)

    started = time.time()
    for batch in batcher:
        # Re-point the layer at the current batch before decoding.
        pointer.batched_data = batch
        layer(node_repr)
    print('Elapsed:', time.time() - started)
def test_batch_decode_02():
    """Smoke-test batched decoding on the large fixture.

    Same procedure as the small-fixture variant, but uses
    ``_some_data_big()`` and input representations of 2000 and 2100 rows.
    The layer is invoked once per batch from the ``DataBatcher`` and the
    elapsed wall-clock time is printed. The test makes no assertions.
    """
    prepared = prepare_training(_some_data_big(), TrainValTest(.8, .1, .1))
    batcher = DataBatcher(prepared, 512)
    pointer = BatchedDataPointer(batched_data_skeleton(prepared))
    node_repr = [torch.rand(2000, 32), torch.rand(2100, 32)]
    layer = DecodeLayer([32, 32], prepared, batched_data_pointer=pointer)

    started = time.time()
    for batch in batcher:
        # Re-point the layer at the current batch before decoding.
        pointer.batched_data = batch
        layer(node_repr)
    print('Elapsed:', time.time() - started)