IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Pārlūkot izejas kodu

Add type checks to DecodeLayer.

master
Stanislaw Adaszewski pirms 3 gadiem
vecāks
revīzija
56ce7aa60b
3 mainītis faili ar 30 papildinājumiem un 9 dzēšanām
  1. +17
    -3
      src/icosagon/declayer.py
  2. +2
    -2
      src/icosagon/trainprep.py
  3. +11
    -4
      tests/icosagon/test_declayer.py

+ 17
- 3
src/icosagon/declayer.py Parādīt failu

@@ -20,7 +20,7 @@ from .decode import DEDICOMDecoder
class DecodeLayer(torch.nn.Module): class DecodeLayer(torch.nn.Module):
def __init__(self, def __init__(self,
input_dim: List[int], input_dim: List[int],
data: Union[Data, PreparedData],
data: PreparedData,
keep_prob: float = 1., keep_prob: float = 1.,
decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
@@ -28,8 +28,22 @@ class DecodeLayer(torch.nn.Module):
super().__init__(**kwargs) super().__init__(**kwargs)
assert all([ a == input_dim[0] \
for a in input_dim ])
if not isinstance(input_dim, list):
raise TypeError('input_dim must be a List')
if not all([ a == input_dim[0] for a in input_dim ]):
raise ValueError('All elements of input_dim must have the same value')
if not isinstance(data, PreparedData):
raise TypeError('data must be an instance of PreparedData')
if not isinstance(decoder_class, type) and \
not isinstance(decoder_class, dict):
raise TypeError('decoder_class must be a Type or a Dict')
if not isinstance(decoder_class, dict):
decoder_class = { k: decoder_class \
for k in data.relation_types.keys() }
self.input_dim = input_dim self.input_dim = input_dim
self.output_dim = 1 self.output_dim = 1


+ 2
- 2
src/icosagon/trainprep.py Parādīt failu

@@ -133,7 +133,7 @@ def prepare_relation_type(r: RelationType,
adj_mat_train, edges_pos, edges_neg) adj_mat_train, edges_pos, edges_neg)
def prepare_training(data: Data) -> PreparedData:
def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
if not isinstance(data, Data): if not isinstance(data, Data):
raise ValueError('data must be of class Data') raise ValueError('data must be of class Data')
@@ -141,5 +141,5 @@ def prepare_training(data: Data) -> PreparedData:
for (node_type_row, node_type_column), rels in data.relation_types.items(): for (node_type_row, node_type_column), rels in data.relation_types.items():
for r in rels: for r in rels:
relation_types[node_type_row, node_type_column].append( relation_types[node_type_row, node_type_column].append(
prep_relation_type(r))
prepare_relation_type(r, ratios))
return PreparedData(data.node_types, relation_types) return PreparedData(data.node_types, relation_types)

+ 11
- 4
tests/icosagon/test_declayer.py Parādīt failu

@@ -9,6 +9,8 @@ from icosagon.convlayer import DecagonLayer
from icosagon.declayer import DecodeLayer from icosagon.declayer import DecodeLayer
from icosagon.decode import DEDICOMDecoder from icosagon.decode import DEDICOMDecoder
from icosagon.data import Data from icosagon.data import Data
from icosagon.trainprep import prepare_training, \
TrainValTest
import torch import torch
@@ -17,11 +19,12 @@ def test_decode_layer_01():
d.add_node_type('Dummy', 100) d.add_node_type('Dummy', 100)
d.add_relation_type('Dummy Relation 1', 0, 0, d.add_relation_type('Dummy Relation 1', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d) in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d) d_layer = DecagonLayer(in_layer.output_dim, 32, d)
seq = torch.nn.Sequential(in_layer, d_layer) seq = torch.nn.Sequential(in_layer, d_layer)
last_layer_repr = seq(None) last_layer_repr = seq(None)
dec = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x) decoder_class=DEDICOMDecoder, activation=lambda x: x)
pred_adj_matrices = dec(last_layer_repr) pred_adj_matrices = dec(last_layer_repr)
assert isinstance(pred_adj_matrices, dict) assert isinstance(pred_adj_matrices, dict)
@@ -35,10 +38,11 @@ def test_decode_layer_02():
d.add_node_type('Dummy', 100) d.add_node_type('Dummy', 100)
d.add_relation_type('Dummy Relation 1', 0, 0, d.add_relation_type('Dummy Relation 1', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d) in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d) d_layer = DecagonLayer(in_layer.output_dim, 32, d)
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x) decoder_class=DEDICOMDecoder, activation=lambda x: x)
seq = torch.nn.Sequential(in_layer, d_layer, dec_layer) seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
@@ -56,10 +60,11 @@ def test_decode_layer_03():
d.add_node_type('Dummy 2', 100) d.add_node_type('Dummy 2', 100)
d.add_relation_type('Dummy Relation 1', 0, 1, d.add_relation_type('Dummy Relation 1', 0, 1,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d) in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d) d_layer = DecagonLayer(in_layer.output_dim, 32, d)
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class={(0, 1): DEDICOMDecoder}, activation=lambda x: x) decoder_class={(0, 1): DEDICOMDecoder}, activation=lambda x: x)
seq = torch.nn.Sequential(in_layer, d_layer, dec_layer) seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
@@ -77,9 +82,11 @@ def test_decode_layer_04():
d.add_node_type('Dummy', 100) d.add_node_type('Dummy', 100)
assert len(d.relation_types[0, 0]) == 0 assert len(d.relation_types[0, 0]) == 0
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d) in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d) d_layer = DecagonLayer(in_layer.output_dim, 32, d)
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=d, keep_prob=1.,
dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x) decoder_class=DEDICOMDecoder, activation=lambda x: x)
seq = torch.nn.Sequential(in_layer, d_layer, dec_layer) seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)


Notiek ielāde…
Atcelt
Saglabāt