IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add type checks to DecodeLayer.

master
Stanislaw Adaszewski 4 years ago
parent
commit
56ce7aa60b
3 changed files with 30 additions and 9 deletions
  1. +17
    -3
      src/icosagon/declayer.py
  2. +2
    -2
      src/icosagon/trainprep.py
  3. +11
    -4
      tests/icosagon/test_declayer.py

+ 17
- 3
src/icosagon/declayer.py View File

@@ -20,7 +20,7 @@ from .decode import DEDICOMDecoder
class DecodeLayer(torch.nn.Module): class DecodeLayer(torch.nn.Module):
def __init__(self, def __init__(self,
input_dim: List[int], input_dim: List[int],
data: Union[Data, PreparedData],
data: PreparedData,
keep_prob: float = 1., keep_prob: float = 1.,
decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
@@ -28,8 +28,22 @@ class DecodeLayer(torch.nn.Module):
super().__init__(**kwargs) super().__init__(**kwargs)
assert all([ a == input_dim[0] \
for a in input_dim ])
if not isinstance(input_dim, list):
raise TypeError('input_dim must be a List')
if not all([ a == input_dim[0] for a in input_dim ]):
raise ValueError('All elements of input_dim must have the same value')
if not isinstance(data, PreparedData):
raise TypeError('data must be an instance of PreparedData')
if not isinstance(decoder_class, type) and \
not isinstance(decoder_class, dict):
raise TypeError('decoder_class must be a Type or a Dict')
if not isinstance(decoder_class, dict):
decoder_class = { k: decoder_class \
for k in data.relation_types.keys() }
self.input_dim = input_dim self.input_dim = input_dim
self.output_dim = 1 self.output_dim = 1


+ 2
- 2
src/icosagon/trainprep.py View File

@@ -133,7 +133,7 @@ def prepare_relation_type(r: RelationType,
adj_mat_train, edges_pos, edges_neg) adj_mat_train, edges_pos, edges_neg)
def prepare_training(data: Data) -> PreparedData:
def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
if not isinstance(data, Data): if not isinstance(data, Data):
raise ValueError('data must be of class Data') raise ValueError('data must be of class Data')
@@ -141,5 +141,5 @@ def prepare_training(data: Data) -> PreparedData:
for (node_type_row, node_type_column), rels in data.relation_types.items(): for (node_type_row, node_type_column), rels in data.relation_types.items():
for r in rels: for r in rels:
relation_types[node_type_row, node_type_column].append( relation_types[node_type_row, node_type_column].append(
prep_relation_type(r))
prepare_relation_type(r, ratios))
return PreparedData(data.node_types, relation_types) return PreparedData(data.node_types, relation_types)

+ 11
- 4
tests/icosagon/test_declayer.py View File

@@ -9,6 +9,8 @@ from icosagon.convlayer import DecagonLayer
from icosagon.declayer import DecodeLayer from icosagon.declayer import DecodeLayer
from icosagon.decode import DEDICOMDecoder from icosagon.decode import DEDICOMDecoder
from icosagon.data import Data from icosagon.data import Data
from icosagon.trainprep import prepare_training, \
TrainValTest
import torch import torch
@@ -17,11 +19,12 @@ def test_decode_layer_01():
d.add_node_type('Dummy', 100) d.add_node_type('Dummy', 100)
d.add_relation_type('Dummy Relation 1', 0, 0, d.add_relation_type('Dummy Relation 1', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d) in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d) d_layer = DecagonLayer(in_layer.output_dim, 32, d)
seq = torch.nn.Sequential(in_layer, d_layer) seq = torch.nn.Sequential(in_layer, d_layer)
last_layer_repr = seq(None) last_layer_repr = seq(None)
dec = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x) decoder_class=DEDICOMDecoder, activation=lambda x: x)
pred_adj_matrices = dec(last_layer_repr) pred_adj_matrices = dec(last_layer_repr)
assert isinstance(pred_adj_matrices, dict) assert isinstance(pred_adj_matrices, dict)
@@ -35,10 +38,11 @@ def test_decode_layer_02():
d.add_node_type('Dummy', 100) d.add_node_type('Dummy', 100)
d.add_relation_type('Dummy Relation 1', 0, 0, d.add_relation_type('Dummy Relation 1', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d) in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d) d_layer = DecagonLayer(in_layer.output_dim, 32, d)
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x) decoder_class=DEDICOMDecoder, activation=lambda x: x)
seq = torch.nn.Sequential(in_layer, d_layer, dec_layer) seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
@@ -56,10 +60,11 @@ def test_decode_layer_03():
d.add_node_type('Dummy 2', 100) d.add_node_type('Dummy 2', 100)
d.add_relation_type('Dummy Relation 1', 0, 1, d.add_relation_type('Dummy Relation 1', 0, 1,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d) in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d) d_layer = DecagonLayer(in_layer.output_dim, 32, d)
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class={(0, 1): DEDICOMDecoder}, activation=lambda x: x) decoder_class={(0, 1): DEDICOMDecoder}, activation=lambda x: x)
seq = torch.nn.Sequential(in_layer, d_layer, dec_layer) seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
@@ -77,9 +82,11 @@ def test_decode_layer_04():
d.add_node_type('Dummy', 100) d.add_node_type('Dummy', 100)
assert len(d.relation_types[0, 0]) == 0 assert len(d.relation_types[0, 0]) == 0
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d) in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d) d_layer = DecagonLayer(in_layer.output_dim, 32, d)
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x) decoder_class=DEDICOMDecoder, activation=lambda x: x)
seq = torch.nn.Sequential(in_layer, d_layer, dec_layer) seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)


Loading…
Cancel
Save