IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

82 wiersze
2.0KB

  1. from icosagon.data import Data
  2. from icosagon.trainprep import prepare_training, \
  3. TrainValTest
  4. from icosagon.model import Model
  5. from icosagon.trainloop import TrainLoop
  6. import torch
  7. import pytest
  8. import pdb
  9. import time
  def test_train_loop_01():
      """Check the attributes of a TrainLoop constructed from a model alone.

      The fixture graph has one node type, 'Dummy', with 10 nodes, and one
      relation type whose adjacency matrix is a random 10x10 matrix of zeros
      and ones. The loop must keep the model it was given and report
      lr == 0.001, binary_cross_entropy_with_logits as its loss, and a batch
      size of 100 when nothing else is passed to it.
      """
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      # Rounding uniform [0, 1) samples yields a random 0/1 matrix.
      adjacency = torch.rand(10, 10).round()
      family.add_relation_type('Dummy Rel', adjacency)
      # NOTE(review): the three fractions presumably mean train/val/test.
      prepared = prepare_training(data, TrainValTest(.8, .1, .1))
      model = Model(prepared)

      train_loop = TrainLoop(model)

      # The model and the loss function are checked by identity: the loop
      # must hold the very objects, not equal-looking copies.
      assert train_loop.model is model
      assert train_loop.lr == 0.001
      assert train_loop.loss is torch.nn.functional.binary_cross_entropy_with_logits
      assert train_loop.batch_size == 100
  def test_train_loop_02():
      """Smoke test: a single ``run_epoch()`` call completes without raising.

      Uses the same fixture as the attribute test: one 'Dummy' node type with
      10 nodes and one relation type with a random 0/1 10x10 adjacency matrix.
      """
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      family.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
      # Evaluation order is the same as building each stage separately:
      # prepare_training -> Model -> TrainLoop.
      trainer = TrainLoop(Model(prepare_training(data, TrainValTest(.8, .1, .1))))
      trainer.run_epoch()
  def test_train_loop_03():
      """Run one training epoch with the data and the model on ``cuda:0``.

      The test is skipped when CUDA is not available. Steps:

      1. Create the random 0/1 adjacency matrix and move it to ``cuda:0``
         before registering it.
      2. Move the model to the same device.
      3. Check that every parameter ended up there.
      4. Run a single epoch.
      """
      # torch.cuda.is_available() is the idiomatic "can we use CUDA?" check.
      if not torch.cuda.is_available():
          pytest.skip('CUDA required for this test')
      adj_mat = torch.rand(10, 10).round()
      dev = torch.device('cuda:0')
      adj_mat = adj_mat.to(dev)
      d = Data()
      d.add_node_type('Dummy', 10)
      fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
      fam.add_relation_type('Dummy Rel', adj_mat)
      prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
      m = Model(prep_d)
      m = m.to(dev)
      # A parameter left on the CPU would make run_epoch() fail with a
      # device-mismatch error that is harder to diagnose than this assertion.
      for prm in m.parameters():
          assert prm.device == dev
      loop = TrainLoop(m)
      loop.run_epoch()
  def test_timing_01(size=2000, density=.001, repeats=1300):
      """Benchmark repeated sparse x dense products and print the elapsed time.

      ``adj_mat`` is a ``size`` x ``size`` sparse (COO) matrix in which each
      entry is 1.0 with probability ``density``. ``rep`` is a dense identity
      matrix with ``requires_grad`` set, so each product also records an
      autograd node. The product is computed ``repeats`` times.

      The defaults reproduce the original benchmark (2000, 0.001, 1300).
      Pytest does not treat arguments that have default values as fixtures,
      so collection still calls this with the defaults. Nothing is asserted:
      this only reports timing.
      """
      adj_mat = (torch.rand(size, size) < density).to(torch.float32).to_sparse()
      rep = torch.eye(size).requires_grad_(True)
      # perf_counter() is monotonic and high-resolution. time.time() follows
      # the adjustable system clock and can jump while the benchmark runs.
      start = time.perf_counter()
      for _ in range(repeats):
          _ = torch.sparse.mm(adj_mat, rep)
      print('Elapsed:', time.perf_counter() - start)