@@ -122,5 +122,5 @@ class DecagonLayer(torch.nn.Module):
             next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row])
             next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row])
-        print('DecagonLayer.forward() took', time.time() - t)
+        # print('DecagonLayer.forward() took', time.time() - t)
         return next_layer_repr
@@ -11,6 +11,11 @@ class BatchedData(PreparedData):
         super().__init__(*args, **kwargs)
+class BatchedDataPointer(object):
+    def __init__(self, batched_data):
+        self.batched_data = batched_data
 def batched_data_skeleton(data: PreparedData) -> BatchedData:
     if not isinstance(data, PreparedData):
         raise TypeError('data must be an instance of PreparedData')
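BatchedDataPointer is only a one-field holder, but it gives a consumer a level of indirection: the consumer keeps a reference to the pointer object, and the training loop swaps ptr.batched_data between mini-batches without rebuilding the consumer. A minimal, self-contained sketch of that pattern follows; the Consumer class is an illustrative stand-in for DecodeLayer, not icosagon code.

class BatchedDataPointer(object):
    def __init__(self, batched_data):
        self.batched_data = batched_data

class Consumer(object):
    # Illustrative stand-in for DecodeLayer (not part of icosagon).
    def __init__(self, data, batched_data_pointer=None):
        self.data = data
        self.batched_data_pointer = batched_data_pointer

    def forward(self):
        # Read through the pointer when one was supplied, else fall back to the full data.
        return self.batched_data_pointer.batched_data \
            if self.batched_data_pointer is not None \
            else self.data

ptr = BatchedDataPointer(None)
consumer = Consumer(data='full data', batched_data_pointer=ptr)
for batch in ('batch-0', 'batch-1', 'batch-2'):
    ptr.batched_data = batch          # redirect the pointer; the consumer is untouched
    assert consumer.forward() == batch
assert Consumer(data='full data').forward() == 'full data'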
@@ -17,6 +17,7 @@ from typing import Type, \
 from .decode import DEDICOMDecoder
 from dataclasses import dataclass
 import time
+from .databatch import BatchedDataPointer
 @dataclass
@@ -43,6 +44,7 @@ class DecodeLayer(torch.nn.Module):
         data: PreparedData,
         keep_prob: float = 1.,
         activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
+        batched_data_pointer: BatchedDataPointer = None,
         **kwargs) -> None:
         super().__init__(**kwargs)
@@ -59,11 +61,19 @@ class DecodeLayer(torch.nn.Module):
         if not isinstance(data, PreparedData):
             raise TypeError('data must be an instance of PreparedData')
+        if batched_data_pointer is not None and \
+            not isinstance(batched_data_pointer, BatchedDataPointer):
+            raise TypeError('batched_data_pointer must be an instance of BatchedDataPointer')
+        # if batched_data_pointer is not None and not batched_data_pointer.compatible_with(data):
+        #     raise ValueError('batched_data_pointer must be compatible with data')
         self.input_dim = input_dim[0]
         self.output_dim = 1
         self.data = data
         self.keep_prob = keep_prob
         self.activation = activation
+        self.batched_data_pointer = batched_data_pointer
         self.decoders = None
         self.build()
@@ -88,13 +98,16 @@ class DecodeLayer(torch.nn.Module):
                 tvt.append(dec(inputs_row, inputs_column, k))
             tvt = TrainValTest(*tvt)
             pred.append(tvt)
-        print('DecodeLayer._get_tvt() took:', time.time() - start_time)
+        # print('DecodeLayer._get_tvt() took:', time.time() - start_time)
         return pred
     def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]:
         t = time.time()
         res = []
-        for i, fam in enumerate(self.data.relation_families):
+        data = self.batched_data_pointer.batched_data \
+            if self.batched_data_pointer is not None \
+            else self.data
+        for i, fam in enumerate(data.relation_families):
             fam_pred = []
             for k, r in enumerate(fam.relation_types):
                 pred = []
@@ -107,5 +120,5 @@ class DecodeLayer(torch.nn.Module):
             fam_pred = RelationFamilyPredictions(fam_pred)
             res.append(fam_pred)
         res = Predictions(res)
-        print('DecodeLayer.forward() took', time.time() - t)
+        # print('DecodeLayer.forward() took', time.time() - t)
         return res
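Taken together, the DecodeLayer changes enable the following per-batch call pattern, condensed from the tests added below: build the layer once against the full PreparedData, then deliver each mini-batch by reassigning the pointer before calling forward(). This sketch uses only calls that appear in the tests; the 100/500 node counts are an assumption inferred from the input shapes in test_batch_decode_01, since _some_data() itself is not shown in this diff.

from icosagon.data import Data
from icosagon.databatch import DataBatcher, \
    BatchedDataPointer, \
    batched_data_skeleton
from icosagon.declayer import DecodeLayer
from icosagon.trainprep import prepare_training, \
    TrainValTest
import torch

data = Data()
data.add_node_type('Foo', 100)    # assumed sizes, matching in_repr in test_batch_decode_01
data.add_node_type('Bar', 500)
fam = data.add_relation_family('Foo-Bar', 0, 1, True)
fam.add_relation_type('Foo-Bar', torch.rand(100, 500).round().to_sparse())

prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
ptr = BatchedDataPointer(batched_data_skeleton(prep_d))
dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr)   # built once

last_layer_repr = [ torch.rand(100, 32), torch.rand(500, 32) ]
for batched_data in DataBatcher(prep_d, 512):
    ptr.batched_data = batched_data      # redirect the layer to the current mini-batch
    pred = dec_layer(last_layer_repr)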
@@ -1,9 +1,14 @@
 from icosagon.databatch import DataBatcher, \
-    BatchedData
+    BatchedData, \
+    BatchedDataPointer, \
+    batched_data_skeleton
 from icosagon.data import Data
 from icosagon.trainprep import prepare_training, \
     TrainValTest
+from icosagon.declayer import DecodeLayer
+from icosagon.input import OneHotInputLayer
 import torch
+import time
 def _some_data():
@@ -16,6 +21,16 @@ def _some_data():
     return data
+def _some_data_big():
+    data = Data()
+    data.add_node_type('Foo', 2000)
+    data.add_node_type('Bar', 2100)
+    fam = data.add_relation_family('Foo-Bar', 0, 1, True)
+    adj_mat = torch.rand(2000, 2100).round().to_sparse()
+    fam.add_relation_type('Foo-Bar', adj_mat)
+    return data
 def test_data_batcher_01():
     data = _some_data()
     prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
@@ -79,3 +94,33 @@ def test_data_batcher_05():
     assert all([ len(edges) <= 512 for edges in edges_list ])
     assert not all([ len(edges) == 0 for edges in edges_list ])
     print(sum(map(len, edges_list)))
+def test_batch_decode_01():
+    data = _some_data()
+    prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+    batcher = DataBatcher(prep_d, 512)
+    ptr = BatchedDataPointer(batched_data_skeleton(prep_d))
+    in_repr = [ torch.rand(100, 32),
+        torch.rand(500, 32) ]
+    dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr)
+    t = time.time()
+    for batched_data in batcher:
+        ptr.batched_data = batched_data
+        _ = dec_layer(in_repr)
+    print('Elapsed:', time.time() - t)
+def test_batch_decode_02():
+    data = _some_data_big()
+    prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+    batcher = DataBatcher(prep_d, 512)
+    ptr = BatchedDataPointer(batched_data_skeleton(prep_d))
+    in_repr = [ torch.rand(2000, 32),
+        torch.rand(2100, 32) ]
+    dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr)
+    t = time.time()
+    for batched_data in batcher:
+        ptr.batched_data = batched_data
+        _ = dec_layer(in_repr)
+    print('Elapsed:', time.time() - t)
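A quick sanity check on the _some_data_big() fixture used by test_batch_decode_02 (my own arithmetic, not part of the patch): torch.rand(...).round() sets each entry to 1 with probability of roughly 0.5, so the big adjacency carries about 2000 * 2100 * 0.5, i.e. around 2.1 million edges before the train/val/test split, which is what makes the 512-edge batching worth timing.

import torch

adj_mat = torch.rand(2000, 2100).round().to_sparse()
print(adj_mat._nnz())        # ~2.1e6 non-zeros, about half of 2000 * 2100
print(2000 * 2100 * 0.5)     # 2100000.0 expected edge count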