IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Pārlūkot izejas kodu

Fix test_model.

master
Stanislaw Adaszewski pirms 4 gadiem
vecāks
revīzija
2b388e4431
3 mainītis faili ar 16 papildinājumiem un 32 dzēšanām
  1. +1
    -1
      src/icosagon/trainloop.py
  2. +0
    -31
      tests/icosagon/test_model.py
  3. +15
    -0
      tests/icosagon/test_trainloop.py

+ 1
- 1
src/icosagon/trainloop.py Parādīt failu

@@ -41,7 +41,7 @@ class TrainLoop(object):
loss_sum = 0
for i in range(n):
self.opt.zero_grad()
pred = self.seq(None)
pred = self.model(None)
batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
seed = torch.rand(1).item()
rng_state = torch.get_rng_state()


+ 0
- 31
tests/icosagon/test_model.py Parādīt failu

@@ -37,11 +37,7 @@ def test_model_01():
assert _is_identity_function(m.rel_activation)
assert m.layer_activation == torch.nn.functional.relu
assert _is_identity_function(m.dec_activation)
assert m.lr == 0.001
assert m.loss == torch.nn.functional.binary_cross_entropy_with_logits
assert m.batch_size == 100
assert isinstance(m.seq, torch.nn.Sequential)
assert isinstance(m.opt, torch.optim.Optimizer)
def test_model_02():
@@ -83,15 +79,10 @@ def test_model_03():
m = Model(prep_d)
state_dict = m.opt.state_dict()
assert isinstance(state_dict, dict)
# print(state_dict['param_groups'])
# print(list(m.seq.parameters()))
assert len(list(m.seq[0].parameters())) == 1
assert len(list(m.seq[1].parameters())) == 1
assert len(list(m.seq[2].parameters())) == 1
assert len(list(m.seq[3].parameters())) == 2
# print(list(m.seq[1].parameters()))
def test_model_04():
@@ -189,25 +180,3 @@ def test_model_07():
with pytest.raises(TypeError):
m = Model(prep_d, dec_activation='x')
with pytest.raises(ValueError):
m = Model(prep_d, lr='x')
with pytest.raises(TypeError):
m = Model(prep_d, loss=1)
with pytest.raises(ValueError):
m = Model(prep_d, batch_size='x')
def test_model_08():
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
m = Model(prep_d)
m.run_epoch()

+ 15
- 0
tests/icosagon/test_trainloop.py Parādīt failu

@@ -22,3 +22,18 @@ def test_train_loop_01():
assert loop.lr == 0.001
assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
assert loop.batch_size == 100
def test_train_loop_02():
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
m = Model(prep_d)
loop = TrainLoop(m)
loop.run_epoch()

Notiek ielāde…
Atcelt
Saglabāt