From 56ce7aa60b980ec768e33f443a73ab5f41a8159f Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 9 Jun 2020 13:07:08 +0200 Subject: [PATCH] Add type checks to DecodeLayer. --- src/icosagon/declayer.py | 20 +++++++++++++++++--- src/icosagon/trainprep.py | 4 ++-- tests/icosagon/test_declayer.py | 15 +++++++++++---- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index f5ed5b1..9047025 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -20,7 +20,7 @@ from .decode import DEDICOMDecoder class DecodeLayer(torch.nn.Module): def __init__(self, input_dim: List[int], - data: Union[Data, PreparedData], + data: PreparedData, keep_prob: float = 1., decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, @@ -28,8 +28,22 @@ class DecodeLayer(torch.nn.Module): super().__init__(**kwargs) - assert all([ a == input_dim[0] \ - for a in input_dim ]) + if not isinstance(input_dim, list): + raise TypeError('input_dim must be a List') + + if not all([ a == input_dim[0] for a in input_dim ]): + raise ValueError('All elements of input_dim must have the same value') + + if not isinstance(data, PreparedData): + raise TypeError('data must be an instance of PreparedData') + + if not isinstance(decoder_class, type) and \ + not isinstance(decoder_class, dict): + raise TypeError('decoder_class must be a Type or a Dict') + + if not isinstance(decoder_class, dict): + decoder_class = { k: decoder_class \ + for k in data.relation_types.keys() } self.input_dim = input_dim self.output_dim = 1 diff --git a/src/icosagon/trainprep.py b/src/icosagon/trainprep.py index a979290..a505d9b 100644 --- a/src/icosagon/trainprep.py +++ b/src/icosagon/trainprep.py @@ -133,7 +133,7 @@ def prepare_relation_type(r: RelationType, adj_mat_train, edges_pos, edges_neg) -def prepare_training(data: Data) -> PreparedData: +def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: if not isinstance(data, Data): raise ValueError('data must be of class Data') @@ -141,5 +141,5 @@ def prepare_training(data: Data) -> PreparedData: for (node_type_row, node_type_column), rels in data.relation_types.items(): for r in rels: relation_types[node_type_row, node_type_column].append( - prep_relation_type(r)) + prepare_relation_type(r, ratios)) return PreparedData(data.node_types, relation_types) diff --git a/tests/icosagon/test_declayer.py b/tests/icosagon/test_declayer.py index 2649759..3536387 100644 --- a/tests/icosagon/test_declayer.py +++ b/tests/icosagon/test_declayer.py @@ -9,6 +9,8 @@ from icosagon.convlayer import DecagonLayer from icosagon.declayer import DecodeLayer from icosagon.decode import DEDICOMDecoder from icosagon.data import Data +from icosagon.trainprep import prepare_training, \ + TrainValTest import torch @@ -17,11 +19,12 @@ def test_decode_layer_01(): d.add_node_type('Dummy', 100) d.add_relation_type('Dummy Relation 1', 0, 0, torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) in_layer = OneHotInputLayer(d) d_layer = DecagonLayer(in_layer.output_dim, 32, d) seq = torch.nn.Sequential(in_layer, d_layer) last_layer_repr = seq(None) - dec = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1., + dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1., decoder_class=DEDICOMDecoder, activation=lambda x: x) pred_adj_matrices = dec(last_layer_repr) assert isinstance(pred_adj_matrices, dict) @@ -35,10 +38,11 @@ def test_decode_layer_02(): d.add_node_type('Dummy', 100) d.add_relation_type('Dummy Relation 1', 0, 0, torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) in_layer = OneHotInputLayer(d) d_layer = DecagonLayer(in_layer.output_dim, 32, d) - dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1., + dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1., decoder_class=DEDICOMDecoder, activation=lambda x: x) seq = torch.nn.Sequential(in_layer, d_layer, dec_layer) @@ -56,10 +60,11 @@ def test_decode_layer_03(): d.add_node_type('Dummy 2', 100) d.add_relation_type('Dummy Relation 1', 0, 1, torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) in_layer = OneHotInputLayer(d) d_layer = DecagonLayer(in_layer.output_dim, 32, d) - dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1., + dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1., decoder_class={(0, 1): DEDICOMDecoder}, activation=lambda x: x) seq = torch.nn.Sequential(in_layer, d_layer, dec_layer) @@ -77,9 +82,11 @@ def test_decode_layer_04(): d.add_node_type('Dummy', 100) assert len(d.relation_types[0, 0]) == 0 + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) + in_layer = OneHotInputLayer(d) d_layer = DecagonLayer(in_layer.output_dim, 32, d) - dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1., + dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1., decoder_class=DEDICOMDecoder, activation=lambda x: x) seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)