IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

115 lines
3.1KB

  1. from icosagon.data import Data
  2. from icosagon.trainprep import prepare_training, \
  3. TrainValTest
  4. from icosagon.model import Model
  5. from icosagon.trainloop import TrainLoop
  6. import torch
  7. import pytest
  8. import pdb
  9. import time
  def test_train_loop_01():
      """Check the attributes of a TrainLoop built from a model alone.

      A 10-node 'Dummy' data set with a single random 0/1 relation is
      prepared with TrainValTest(.8, .1, .1) and wrapped in a Model. The
      TrainLoop is constructed with only that model, and the test asserts
      the values it then exposes for model, lr, loss and batch_size.
      """
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      adjacency = torch.rand(10, 10).round()
      family.add_relation_type('Dummy Rel', adjacency)

      split = TrainValTest(.8, .1, .1)
      model = Model(prepare_training(data, split))
      loop = TrainLoop(model)

      assert loop.model == model
      assert loop.lr == 0.001
      assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
      assert loop.batch_size == 100
  def test_train_loop_02():
      """Smoke-test a single training epoch on a tiny random graph.

      Builds a 10-node 'Dummy' data set with one random 0/1 relation,
      prepares it with TrainValTest(.8, .1, .1) and calls run_epoch() once.
      Nothing is asserted; the test passes when no exception is raised.
      """
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      adjacency = torch.rand(10, 10).round()
      family.add_relation_type('Dummy Rel', adjacency)

      split = TrainValTest(.8, .1, .1)
      prepared = prepare_training(data, split)
      model = Model(prepared)

      TrainLoop(model).run_epoch()
  def test_train_loop_03():
      """Run one training epoch with the data and model placed on 'cuda:0'.

      The test is skipped when torch reports zero CUDA devices. Otherwise a
      random 10x10 0/1 adjacency matrix is moved to the GPU before being
      registered as the relation, the model is moved to the same device,
      every model parameter is asserted to live there, and run_epoch() is
      called once.
      """
      if not torch.cuda.device_count():
          pytest.skip('CUDA required for this test')

      device = torch.device('cuda:0')
      adjacency = torch.rand(10, 10).round().to(device)

      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      family.add_relation_type('Dummy Rel', adjacency)

      split = TrainValTest(.8, .1, .1)
      model = Model(prepare_training(data, split)).to(device)

      # Dump the parameters so a failing device check is easy to diagnose.
      print(list(model.parameters()))
      for param in model.parameters():
          assert param.device == device

      TrainLoop(model).run_epoch()
  def test_timing_01():
      """Time 1300 sparse-by-dense products and print the elapsed seconds.

      A random 2000x2000 matrix in which each entry is 1 with probability
      0.001 is converted to sparse layout and multiplied, 1300 times, with a
      dense 2000x2000 identity matrix that has requires_grad enabled.
      Nothing is asserted; the test only reports how long the loop took.
      """
      adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
      rep = torch.eye(2000).requires_grad_(True)

      # perf_counter() is monotonic and high-resolution; time.time() follows
      # the wall clock and can jump if the system time is adjusted mid-run.
      start = time.perf_counter()
      for _ in range(1300):
          torch.sparse.mm(adj_mat, rep)
      print('Elapsed:', time.perf_counter() - start)
  def test_timing_02():
      """Time one batched dense matmul over 1300 copies of a matrix.

      A random 2000x2000 float32 matrix in which each entry is 1 with
      probability 0.001 is stacked 1300 times into a (1300, 2000, 2000)
      batch and multiplied with a 2000x2000 identity matrix that has
      requires_grad enabled. The elapsed time and the result shape are
      printed; nothing is asserted.

      Note: the dense batch alone holds 1300 * 2000 * 2000 float32 values,
      i.e. about 20.8 GB, and the product has the same shape.
      """
      adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32)
      adj_mat_batch = torch.cat([adj_mat.view(1, 2000, 2000)] * 1300)
      rep = torch.eye(2000).requires_grad_(True)

      # perf_counter() is monotonic and high-resolution; time.time() follows
      # the wall clock and can jump if the system time is adjusted mid-run.
      start = time.perf_counter()
      res = torch.matmul(adj_mat_batch, rep)
      print('Elapsed:', time.perf_counter() - start)
      print('res.shape:', res.shape)
  def test_timing_03():
      """Expect torch.bmm to raise for a sparse batch and a 2-D dense matrix.

      A random 2000x2000 float32 matrix in which each entry is 1 with
      probability 0.001 is viewed as (1, 2000, 2000), converted to sparse
      layout and concatenated 1300 times. torch.bmm is then called with that
      sparse batch and a 2000x2000 identity matrix; the test requires the
      call to raise RuntimeError and prints the time spent in it.
      """
      adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32)
      adj_mat_batch = torch.cat(
          [adj_mat.view(1, 2000, 2000).to_sparse()] * 1300)
      rep = torch.eye(2000).requires_grad_(True)

      # NOTE(review): this test previously also built a dense
      # (1300, 2000, 2000) batch of `rep` (about 20.8 GB) that was never
      # passed to bmm; that dead allocation is removed. Confirm that the 2-D
      # `rep` is the intended second operand.

      # perf_counter() is monotonic and high-resolution; time.time() follows
      # the wall clock and can jump if the system time is adjusted mid-run.
      start = time.perf_counter()
      with pytest.raises(RuntimeError):
          torch.bmm(adj_mat_batch, rep)
      print('Elapsed:', time.perf_counter() - start)
  def test_timing_04():
      """Time 1300 sparse-by-dense products for a very sparse matrix.

      A random 2000x2000 matrix in which each entry is 1 with probability
      0.0001 is converted to sparse layout and multiplied, 1300 times, with
      a dense 2000x2000 identity matrix that has requires_grad enabled.
      Nothing is asserted; the test only reports how long the loop took.
      """
      adj_mat = (torch.rand(2000, 2000) < .0001).to(torch.float32).to_sparse()
      rep = torch.eye(2000).requires_grad_(True)

      # perf_counter() is monotonic and high-resolution; time.time() follows
      # the wall clock and can jump if the system time is adjusted mid-run.
      start = time.perf_counter()
      for _ in range(1300):
          torch.sparse.mm(adj_mat, rep)
      print('Elapsed:', time.perf_counter() - start)