IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

40 lines
1.0KB

  1. from icosagon.data import Data
  2. from icosagon.trainprep import prepare_training, \
  3. TrainValTest
  4. from icosagon.model import Model
  5. from icosagon.trainloop import TrainLoop
  6. import torch
  7. def test_train_loop_01():
  8. d = Data()
  9. d.add_node_type('Dummy', 10)
  10. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  11. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  12. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  13. m = Model(prep_d)
  14. loop = TrainLoop(m)
  15. assert loop.model == m
  16. assert loop.lr == 0.001
  17. assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
  18. assert loop.batch_size == 100
  19. def test_train_loop_02():
  20. d = Data()
  21. d.add_node_type('Dummy', 10)
  22. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  23. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  24. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  25. m = Model(prep_d)
  26. loop = TrainLoop(m)
  27. loop.run_epoch()