IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Quellcode durchsuchen

Add type checks to DecodeLayer.

master
Stanislaw Adaszewski vor 4 Jahren
Ursprung
Commit
56ce7aa60b
3 geänderte Dateien mit 30 neuen und 9 gelöschten Zeilen
  1. +17
    -3
      src/icosagon/declayer.py
  2. +2
    -2
      src/icosagon/trainprep.py
  3. +11
    -4
      tests/icosagon/test_declayer.py

+ 17
- 3
src/icosagon/declayer.py Datei anzeigen

@@ -20,7 +20,7 @@ from .decode import DEDICOMDecoder
class DecodeLayer(torch.nn.Module):
def __init__(self,
input_dim: List[int],
data: Union[Data, PreparedData],
data: PreparedData,
keep_prob: float = 1.,
decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
@@ -28,8 +28,22 @@ class DecodeLayer(torch.nn.Module):
super().__init__(**kwargs)
assert all([ a == input_dim[0] \
for a in input_dim ])
if not isinstance(input_dim, list):
raise TypeError('input_dim must be a List')
if not all([ a == input_dim[0] for a in input_dim ]):
raise ValueError('All elements of input_dim must have the same value')
if not isinstance(data, PreparedData):
raise TypeError('data must be an instance of PreparedData')
if not isinstance(decoder_class, type) and \
not isinstance(decoder_class, dict):
raise TypeError('decoder_class must be a Type or a Dict')
if not isinstance(decoder_class, dict):
decoder_class = { k: decoder_class \
for k in data.relation_types.keys() }
self.input_dim = input_dim
self.output_dim = 1


+ 2
- 2
src/icosagon/trainprep.py Datei anzeigen

@@ -133,7 +133,7 @@ def prepare_relation_type(r: RelationType,
adj_mat_train, edges_pos, edges_neg)
def prepare_training(data: Data) -> PreparedData:
def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
if not isinstance(data, Data):
raise ValueError('data must be of class Data')
@@ -141,5 +141,5 @@ def prepare_training(data: Data) -> PreparedData:
for (node_type_row, node_type_column), rels in data.relation_types.items():
for r in rels:
relation_types[node_type_row, node_type_column].append(
prep_relation_type(r))
prepare_relation_type(r, ratios))
return PreparedData(data.node_types, relation_types)

+ 11
- 4
tests/icosagon/test_declayer.py Datei anzeigen

@@ -9,6 +9,8 @@ from icosagon.convlayer import DecagonLayer
from icosagon.declayer import DecodeLayer
from icosagon.decode import DEDICOMDecoder
from icosagon.data import Data
from icosagon.trainprep import prepare_training, \
TrainValTest
import torch
@@ -17,11 +19,12 @@ def test_decode_layer_01():
d.add_node_type('Dummy', 100)
d.add_relation_type('Dummy Relation 1', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d)
seq = torch.nn.Sequential(in_layer, d_layer)
last_layer_repr = seq(None)
dec = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x)
pred_adj_matrices = dec(last_layer_repr)
assert isinstance(pred_adj_matrices, dict)
@@ -35,10 +38,11 @@ def test_decode_layer_02():
d.add_node_type('Dummy', 100)
d.add_relation_type('Dummy Relation 1', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d)
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x)
seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
@@ -56,10 +60,11 @@ def test_decode_layer_03():
d.add_node_type('Dummy 2', 100)
d.add_relation_type('Dummy Relation 1', 0, 1,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d)
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class={(0, 1): DEDICOMDecoder}, activation=lambda x: x)
seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
@@ -77,9 +82,11 @@ def test_decode_layer_04():
d.add_node_type('Dummy', 100)
assert len(d.relation_types[0, 0]) == 0
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d)
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x)
seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)


Laden…
Abbrechen
Speichern