IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

71 lines
1.7KB

  1. from icosagon.data import Data
  2. from icosagon.trainprep import prepare_training, \
  3. TrainValTest
  4. from icosagon.model import Model
  5. from icosagon.trainloop import TrainLoop
  6. import torch
  7. import pytest
  8. import pdb
  9. def test_train_loop_01():
  10. d = Data()
  11. d.add_node_type('Dummy', 10)
  12. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  13. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  14. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  15. m = Model(prep_d)
  16. loop = TrainLoop(m)
  17. assert loop.model == m
  18. assert loop.lr == 0.001
  19. assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
  20. assert loop.batch_size == 100
  21. def test_train_loop_02():
  22. d = Data()
  23. d.add_node_type('Dummy', 10)
  24. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  25. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  26. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  27. m = Model(prep_d)
  28. loop = TrainLoop(m)
  29. loop.run_epoch()
  30. def test_train_loop_03():
  31. if torch.cuda.device_count() == 0:
  32. pytest.skip('CUDA required for this test')
  33. adj_mat = torch.rand(10, 10).round()
  34. dev = torch.device('cuda:0')
  35. adj_mat = adj_mat.to(dev)
  36. d = Data()
  37. d.add_node_type('Dummy', 10)
  38. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  39. fam.add_relation_type('Dummy Rel', adj_mat)
  40. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  41. # pdb.set_trace()
  42. m = Model(prep_d)
  43. m = m.to(dev)
  44. print(list(m.parameters()))
  45. for prm in m.parameters():
  46. assert prm.device == dev
  47. loop = TrainLoop(m)
  48. loop.run_epoch()