IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Parcourir la source

Add test_model_01().

master
Stanislaw Adaszewski il y a 4 ans
Parent
révision
d5779daac3
2 fichiers modifiés avec 43 ajouts et 3 suppressions
  1. +8
    -3
      src/icosagon/model.py
  2. +35
    -0
      tests/icosagon/test_model.py

+ 8
- 3
src/icosagon/model.py Voir le fichier

@@ -1,5 +1,6 @@
from .data import Data
from typing import List
from typing import List, \
Callable
from .trainprep import prepare_training, \
TrainValTest
import torch
@@ -19,13 +20,13 @@ class Model(object):
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
lr: float = 0.001,
loss = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
batch_size: int = 100) -> None:
if not isinstance(data, Data):
raise TypeError('data must be an instance of Data')
if not isinstance(layer_sizes, list):
if not isinstance(layer_dimensions, list):
raise TypeError('layer_dimensions must be a list')
if not isinstance(ratios, TrainValTest):
@@ -60,6 +61,10 @@ class Model(object):
self.loss = loss
self.batch_size = batch_size
self.prep_d = None
self.seq = None
self.opt = None
self.build()
def build(self):


+ 35
- 0
tests/icosagon/test_model.py Voir le fichier

@@ -0,0 +1,35 @@
from icosagon.data import Data
from icosagon.model import Model
from icosagon.trainprep import PreparedData
import torch
import ast
def _is_identity_function(f):
for x in range(-100, 101):
if f(x) != x:
return False
return True
def test_model_01():
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
m = Model(d)
assert m.data == d
assert m.layer_dimensions == [32, 64]
assert (m.ratios.train, m.ratios.val, m.ratios.test) == (.8, .1, .1)
assert m.keep_prob == 1.
assert _is_identity_function(m.rel_activation)
assert m.layer_activation == torch.nn.functional.relu
assert _is_identity_function(m.dec_activation)
assert m.lr == 0.001
assert m.loss == torch.nn.functional.binary_cross_entropy_with_logits
assert m.batch_size == 100
assert isinstance(m.prep_d, PreparedData)
assert isinstance(m.seq, torch.nn.Sequential)
assert isinstance(m.opt, torch.optim.Optimizer)

Chargement…
Annuler
Enregistrer