@@ -122,5 +122,5 @@ class DecagonLayer(torch.nn.Module): | |||||
next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) | next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row]) | ||||
next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) | next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row]) | ||||
print('DecagonLayer.forward() took', time.time() - t) | |||||
# print('DecagonLayer.forward() took', time.time() - t) | |||||
return next_layer_repr | return next_layer_repr |
@@ -11,6 +11,11 @@ class BatchedData(PreparedData): | |||||
super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
class BatchedDataPointer(object): | |||||
def __init__(self, batched_data): | |||||
self.batched_data = batched_data | |||||
def batched_data_skeleton(data: PreparedData) -> BatchedData: | def batched_data_skeleton(data: PreparedData) -> BatchedData: | ||||
if not isinstance(data, PreparedData): | if not isinstance(data, PreparedData): | ||||
raise TypeError('data must be an instance of PreparedData') | raise TypeError('data must be an instance of PreparedData') | ||||
@@ -17,6 +17,7 @@ from typing import Type, \ | |||||
from .decode import DEDICOMDecoder | from .decode import DEDICOMDecoder | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
import time | import time | ||||
from .databatch import BatchedDataPointer | |||||
@dataclass | @dataclass | ||||
@@ -43,6 +44,7 @@ class DecodeLayer(torch.nn.Module): | |||||
data: PreparedData, | data: PreparedData, | ||||
keep_prob: float = 1., | keep_prob: float = 1., | ||||
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | ||||
batched_data_pointer: BatchedDataPointer = None, | |||||
**kwargs) -> None: | **kwargs) -> None: | ||||
super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
@@ -59,11 +61,19 @@ class DecodeLayer(torch.nn.Module): | |||||
if not isinstance(data, PreparedData): | if not isinstance(data, PreparedData): | ||||
raise TypeError('data must be an instance of PreparedData') | raise TypeError('data must be an instance of PreparedData') | ||||
if batched_data_pointer is not None and \ | |||||
not isinstance(batched_data_pointer, BatchedDataPointer): | |||||
raise TypeError('batched_data_pointer must be an instance of BatchedDataPointer') | |||||
# if batched_data_pointer is not None and not batched_data_pointer.compatible_with(data): | |||||
# raise ValueError('batched_data_pointer must be compatible with data') | |||||
self.input_dim = input_dim[0] | self.input_dim = input_dim[0] | ||||
self.output_dim = 1 | self.output_dim = 1 | ||||
self.data = data | self.data = data | ||||
self.keep_prob = keep_prob | self.keep_prob = keep_prob | ||||
self.activation = activation | self.activation = activation | ||||
self.batched_data_pointer = batched_data_pointer | |||||
self.decoders = None | self.decoders = None | ||||
self.build() | self.build() | ||||
@@ -88,13 +98,16 @@ class DecodeLayer(torch.nn.Module): | |||||
tvt.append(dec(inputs_row, inputs_column, k)) | tvt.append(dec(inputs_row, inputs_column, k)) | ||||
tvt = TrainValTest(*tvt) | tvt = TrainValTest(*tvt) | ||||
pred.append(tvt) | pred.append(tvt) | ||||
print('DecodeLayer._get_tvt() took:', time.time() - start_time) | |||||
# print('DecodeLayer._get_tvt() took:', time.time() - start_time) | |||||
return pred | return pred | ||||
def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: | def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: | ||||
t = time.time() | t = time.time() | ||||
res = [] | res = [] | ||||
for i, fam in enumerate(self.data.relation_families): | |||||
data = self.batched_data_pointer.batched_data \ | |||||
if self.batched_data_pointer is not None \ | |||||
else self.data | |||||
for i, fam in enumerate(data.relation_families): | |||||
fam_pred = [] | fam_pred = [] | ||||
for k, r in enumerate(fam.relation_types): | for k, r in enumerate(fam.relation_types): | ||||
pred = [] | pred = [] | ||||
@@ -107,5 +120,5 @@ class DecodeLayer(torch.nn.Module): | |||||
fam_pred = RelationFamilyPredictions(fam_pred) | fam_pred = RelationFamilyPredictions(fam_pred) | ||||
res.append(fam_pred) | res.append(fam_pred) | ||||
res = Predictions(res) | res = Predictions(res) | ||||
print('DecodeLayer.forward() took', time.time() - t) | |||||
# print('DecodeLayer.forward() took', time.time() - t) | |||||
return res | return res |
@@ -1,9 +1,14 @@ | |||||
from icosagon.databatch import DataBatcher, \ | from icosagon.databatch import DataBatcher, \ | ||||
BatchedData | |||||
BatchedData, \ | |||||
BatchedDataPointer, \ | |||||
batched_data_skeleton | |||||
from icosagon.data import Data | from icosagon.data import Data | ||||
from icosagon.trainprep import prepare_training, \ | from icosagon.trainprep import prepare_training, \ | ||||
TrainValTest | TrainValTest | ||||
from icosagon.declayer import DecodeLayer | |||||
from icosagon.input import OneHotInputLayer | |||||
import torch | import torch | ||||
import time | |||||
def _some_data(): | def _some_data(): | ||||
@@ -16,6 +21,16 @@ def _some_data(): | |||||
return data | return data | ||||
def _some_data_big(): | |||||
data = Data() | |||||
data.add_node_type('Foo', 2000) | |||||
data.add_node_type('Bar', 2100) | |||||
fam = data.add_relation_family('Foo-Bar', 0, 1, True) | |||||
adj_mat = torch.rand(2000, 2100).round().to_sparse() | |||||
fam.add_relation_type('Foo-Bar', adj_mat) | |||||
return data | |||||
def test_data_batcher_01(): | def test_data_batcher_01(): | ||||
data = _some_data() | data = _some_data() | ||||
prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) | prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) | ||||
@@ -79,3 +94,33 @@ def test_data_batcher_05(): | |||||
assert all([ len(edges) <= 512 for edges in edges_list ]) | assert all([ len(edges) <= 512 for edges in edges_list ]) | ||||
assert not all([ len(edges) == 0 for edges in edges_list ]) | assert not all([ len(edges) == 0 for edges in edges_list ]) | ||||
print(sum(map(len, edges_list))) | print(sum(map(len, edges_list))) | ||||
def test_batch_decode_01(): | |||||
data = _some_data() | |||||
prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) | |||||
batcher = DataBatcher(prep_d, 512) | |||||
ptr = BatchedDataPointer(batched_data_skeleton(prep_d)) | |||||
in_repr = [ torch.rand(100, 32), | |||||
torch.rand(500, 32) ] | |||||
dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr) | |||||
t = time.time() | |||||
for batched_data in batcher: | |||||
ptr.batched_data = batched_data | |||||
_ = dec_layer(in_repr) | |||||
print('Elapsed:', time.time() - t) | |||||
def test_batch_decode_02(): | |||||
data = _some_data_big() | |||||
prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) | |||||
batcher = DataBatcher(prep_d, 512) | |||||
ptr = BatchedDataPointer(batched_data_skeleton(prep_d)) | |||||
in_repr = [ torch.rand(2000, 32), | |||||
torch.rand(2100, 32) ] | |||||
dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr) | |||||
t = time.time() | |||||
for batched_data in batcher: | |||||
ptr.batched_data = batched_data | |||||
_ = dec_layer(in_repr) | |||||
print('Elapsed:', time.time() - t) |