IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Browse Source

Add BatchedData support to DecodeLayer.

master
Stanislaw Adaszewski 3 years ago
parent
commit
5bc276eb6b
4 changed files with 68 additions and 5 deletions
  1. +1
    -1
      src/icosagon/convlayer.py
  2. +5
    -0
      src/icosagon/databatch.py
  3. +16
    -3
      src/icosagon/declayer.py
  4. +46
    -1
      tests/icosagon/test_databatch.py

+ 1
- 1
src/icosagon/convlayer.py View File

@@ -122,5 +122,5 @@ class DecagonLayer(torch.nn.Module):
next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row])
next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row])
print('DecagonLayer.forward() took', time.time() - t)
# print('DecagonLayer.forward() took', time.time() - t)
return next_layer_repr

+ 5
- 0
src/icosagon/databatch.py View File

@@ -11,6 +11,11 @@ class BatchedData(PreparedData):
super().__init__(*args, **kwargs)
class BatchedDataPointer:
    """Mutable holder for a reference to a batched-data object.

    The instance exposes a single public attribute, ``batched_data``.
    Code that keeps a reference to the pointer (rather than to the batch
    itself) observes whatever object was assigned to that attribute most
    recently, so the batch can be swapped without rebuilding the holder.
    """

    def __init__(self, batched_data: 'BatchedData') -> None:
        """Store the initial batch.

        Args:
            batched_data: Object to expose through ``self.batched_data``.
                No validation is performed here.
        """
        # Deliberately a plain attribute: callers reassign it directly.
        self.batched_data = batched_data
def batched_data_skeleton(data: PreparedData) -> BatchedData:
if not isinstance(data, PreparedData):
raise TypeError('data must be an instance of PreparedData')


+ 16
- 3
src/icosagon/declayer.py View File

@@ -17,6 +17,7 @@ from typing import Type, \
from .decode import DEDICOMDecoder
from dataclasses import dataclass
import time
from .databatch import BatchedDataPointer
@dataclass
@@ -43,6 +44,7 @@ class DecodeLayer(torch.nn.Module):
data: PreparedData,
keep_prob: float = 1.,
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
batched_data_pointer: BatchedDataPointer = None,
**kwargs) -> None:
super().__init__(**kwargs)
@@ -59,11 +61,19 @@ class DecodeLayer(torch.nn.Module):
if not isinstance(data, PreparedData):
raise TypeError('data must be an instance of PreparedData')
if batched_data_pointer is not None and \
not isinstance(batched_data_pointer, BatchedDataPointer):
raise TypeError('batched_data_pointer must be an instance of BatchedDataPointer')
# if batched_data_pointer is not None and not batched_data_pointer.compatible_with(data):
# raise ValueError('batched_data_pointer must be compatible with data')
self.input_dim = input_dim[0]
self.output_dim = 1
self.data = data
self.keep_prob = keep_prob
self.activation = activation
self.batched_data_pointer = batched_data_pointer
self.decoders = None
self.build()
@@ -88,13 +98,16 @@ class DecodeLayer(torch.nn.Module):
tvt.append(dec(inputs_row, inputs_column, k))
tvt = TrainValTest(*tvt)
pred.append(tvt)
print('DecodeLayer._get_tvt() took:', time.time() - start_time)
# print('DecodeLayer._get_tvt() took:', time.time() - start_time)
return pred
def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]:
t = time.time()
res = []
for i, fam in enumerate(self.data.relation_families):
data = self.batched_data_pointer.batched_data \
if self.batched_data_pointer is not None \
else self.data
for i, fam in enumerate(data.relation_families):
fam_pred = []
for k, r in enumerate(fam.relation_types):
pred = []
@@ -107,5 +120,5 @@ class DecodeLayer(torch.nn.Module):
fam_pred = RelationFamilyPredictions(fam_pred)
res.append(fam_pred)
res = Predictions(res)
print('DecodeLayer.forward() took', time.time() - t)
# print('DecodeLayer.forward() took', time.time() - t)
return res

+ 46
- 1
tests/icosagon/test_databatch.py View File

@@ -1,9 +1,14 @@
from icosagon.databatch import DataBatcher, \
BatchedData
BatchedData, \
BatchedDataPointer, \
batched_data_skeleton
from icosagon.data import Data
from icosagon.trainprep import prepare_training, \
TrainValTest
from icosagon.declayer import DecodeLayer
from icosagon.input import OneHotInputLayer
import torch
import time
def _some_data():
@@ -16,6 +21,16 @@ def _some_data():
return data
def _some_data_big(n_foo: int = 2000, n_bar: int = 2100) -> 'Data':
    """Build a large two-node-type fixture with one random relation.

    Registers node types 'Foo' and 'Bar', one relation family 'Foo-Bar'
    and one relation type 'Foo-Bar' whose adjacency matrix is a random
    0/1 matrix of shape ``(n_foo, n_bar)`` converted to sparse storage.

    Args:
        n_foo: Count passed for node type 'Foo'; also the number of rows
            of the adjacency matrix. Defaults to 2000.
        n_bar: Count passed for node type 'Bar'; also the number of
            columns of the adjacency matrix. Defaults to 2100.

    Returns:
        The populated ``Data`` instance.
    """
    data = Data()
    data.add_node_type('Foo', n_foo)
    data.add_node_type('Bar', n_bar)
    # Positional arguments are passed through unchanged; presumably the
    # two node-type indices and a symmetry flag -- TODO confirm against
    # Data.add_relation_family.
    fam = data.add_relation_family('Foo-Bar', 0, 1, True)
    # Uniform samples in [0, 1) rounded to 0/1, then stored sparsely.
    # The shape is derived from the same sizes used for the node types
    # so the two can never drift apart.
    adj_mat = torch.rand(n_foo, n_bar).round().to_sparse()
    fam.add_relation_type('Foo-Bar', adj_mat)
    return data
def test_data_batcher_01():
data = _some_data()
prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
@@ -79,3 +94,33 @@ def test_data_batcher_05():
assert all([ len(edges) <= 512 for edges in edges_list ])
assert not all([ len(edges) == 0 for edges in edges_list ])
print(sum(map(len, edges_list)))
def test_batch_decode_01():
    """Smoke-test batched decoding on the small fixture and time it.

    Prepares the data from ``_some_data()``, builds a ``DecodeLayer``
    that reads its batch through a ``BatchedDataPointer``, then for every
    batch yielded by ``DataBatcher`` repoints the pointer and runs the
    layer. The test makes no assertions on the predictions; it only
    checks that decoding runs and prints the elapsed time.
    """
    data = _some_data()
    prep_d = prepare_training(data, TrainValTest(.8, .1, .1))

    batcher = DataBatcher(prep_d, 512)
    ptr = BatchedDataPointer(batched_data_skeleton(prep_d))

    # One random representation per node type, 32 features each.
    # NOTE(review): row counts 100 / 500 presumably match the node
    # counts defined in _some_data() -- confirm.
    in_repr = [torch.rand(100, 32),
               torch.rand(500, 32)]
    dec_layer = DecodeLayer([32, 32], prep_d, batched_data_pointer=ptr)

    # perf_counter() is monotonic and high-resolution, so the measured
    # interval is not affected by system clock adjustments.
    t = time.perf_counter()
    for batched_data in batcher:
        # Swap the active batch; the layer reads it through the pointer.
        ptr.batched_data = batched_data
        _ = dec_layer(in_repr)
    print('Elapsed:', time.perf_counter() - t)
def test_batch_decode_02():
    """Smoke-test batched decoding on the large fixture and time it.

    Prepares the data from ``_some_data_big()``, builds a ``DecodeLayer``
    that reads its batch through a ``BatchedDataPointer``, then for every
    batch yielded by ``DataBatcher`` repoints the pointer and runs the
    layer. The test makes no assertions on the predictions; it only
    checks that decoding runs and prints the elapsed time.
    """
    data = _some_data_big()
    prep_d = prepare_training(data, TrainValTest(.8, .1, .1))

    batcher = DataBatcher(prep_d, 512)
    ptr = BatchedDataPointer(batched_data_skeleton(prep_d))

    # One random representation per node type, 32 features each; the row
    # counts mirror the 2000 / 2100 nodes created by _some_data_big().
    in_repr = [torch.rand(2000, 32),
               torch.rand(2100, 32)]
    dec_layer = DecodeLayer([32, 32], prep_d, batched_data_pointer=ptr)

    # perf_counter() is monotonic and high-resolution, so the measured
    # interval is not affected by system clock adjustments.
    t = time.perf_counter()
    for batched_data in batcher:
        # Swap the active batch; the layer reads it through the pointer.
        ptr.batched_data = batched_data
        _ = dec_layer(in_repr)
    print('Elapsed:', time.perf_counter() - t)

Loading…
Cancel
Save