From d5779daac3c76a9b8c90a07ed73ff1ac59ec28d3 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Wed, 17 Jun 2020 11:30:07 +0200 Subject: [PATCH] Add test_model_01(). --- src/icosagon/model.py | 11 ++++++++--- tests/icosagon/test_model.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 tests/icosagon/test_model.py diff --git a/src/icosagon/model.py b/src/icosagon/model.py index 4c1cf0d..5d384f0 100644 --- a/src/icosagon/model.py +++ b/src/icosagon/model.py @@ -1,5 +1,6 @@ from .data import Data -from typing import List +from typing import List, \ + Callable from .trainprep import prepare_training, \ TrainValTest import torch @@ -19,13 +20,13 @@ class Model(object): layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, lr: float = 0.001, - loss = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits, batch_size: int = 100) -> None: if not isinstance(data, Data): raise TypeError('data must be an instance of Data') - if not isinstance(layer_sizes, list): + if not isinstance(layer_dimensions, list): raise TypeError('layer_dimensions must be a list') if not isinstance(ratios, TrainValTest): @@ -60,6 +61,10 @@ class Model(object): self.loss = loss self.batch_size = batch_size + self.prep_d = None + self.seq = None + self.opt = None + self.build() def build(self): diff --git a/tests/icosagon/test_model.py b/tests/icosagon/test_model.py new file mode 100644 index 0000000..666960c --- /dev/null +++ b/tests/icosagon/test_model.py @@ -0,0 +1,35 @@ +from icosagon.data import Data +from icosagon.model import Model +from icosagon.trainprep import PreparedData +import torch +import ast + + +def _is_identity_function(f): + for x in range(-100, 101): + if f(x) != x: + return False + return True + + +def test_model_01(): + d = Data() + d.add_node_type('Dummy', 10) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round()) + + m = Model(d) + + assert m.data == d + assert m.layer_dimensions == [32, 64] + assert (m.ratios.train, m.ratios.val, m.ratios.test) == (.8, .1, .1) + assert m.keep_prob == 1. + assert _is_identity_function(m.rel_activation) + assert m.layer_activation == torch.nn.functional.relu + assert _is_identity_function(m.dec_activation) + assert m.lr == 0.001 + assert m.loss == torch.nn.functional.binary_cross_entropy_with_logits + assert m.batch_size == 100 + assert isinstance(m.prep_d, PreparedData) + assert isinstance(m.seq, torch.nn.Sequential) + assert isinstance(m.opt, torch.optim.Optimizer)