IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add test_model_07().

master
Stanislaw Adaszewski 4 years ago
parent
commit
f17f97caf4
2 changed files with 15 additions and 4 deletions
  1. +4
    -4
      src/icosagon/model.py
  2. +11
    -0
      tests/icosagon/test_model.py

+ 4
- 4
src/icosagon/model.py View File

@@ -99,13 +99,13 @@ class Model(object):
def run_epoch(self): def run_epoch(self):
pred = self.seq(None) pred = self.seq(None)
batch = PredictionsBatch(pred, self.batch_size)
batch = PredictionsBatch(pred, batch_size=self.batch_size)
n = len(list(iter(batch))) n = len(list(iter(batch)))
loss_sum = 0 loss_sum = 0
for i in range(n - 1):
for i in range(n):
self.opt.zero_grad() self.opt.zero_grad()
pred = self.seq(None) pred = self.seq(None)
batch = PredictionsBatch(pred, self.batch_size, shuffle=True)
batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
seed = torch.rand(1).item() seed = torch.rand(1).item()
rng_state = torch.get_rng_state() rng_state = torch.get_rng_state()
torch.manual_seed(seed) torch.manual_seed(seed)
@@ -116,7 +116,7 @@ class Model(object):
(input, target) = next(it) (input, target) = next(it)
loss = self.loss(input, target) loss = self.loss(input, target)
loss.backward() loss.backward()
self.opt.optimize()
self.opt.step()
loss_sum += loss.detach().cpu().item() loss_sum += loss.detach().cpu().item()
return loss_sum return loss_sum


+ 11
- 0
tests/icosagon/test_model.py View File

@@ -185,3 +185,14 @@ def test_model_06():
with pytest.raises(ValueError): with pytest.raises(ValueError):
m = Model(d, batch_size='x') m = Model(d, batch_size='x')
def test_model_07():
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
m = Model(d)
m.run_epoch()

Loading…
Cancel
Save