IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or opening pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

134 行
3.8KB

  1. from icosagon.data import Data
  2. from icosagon.trainprep import prepare_training, \
  3. TrainValTest
  4. from icosagon.model import Model
  5. from icosagon.trainloop import TrainLoop
  6. import torch
  7. import pytest
  8. import pdb
  9. import time
  10. def test_train_loop_01():
  11. d = Data()
  12. d.add_node_type('Dummy', 10)
  13. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  14. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  15. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  16. m = Model(prep_d)
  17. loop = TrainLoop(m)
  18. assert loop.model == m
  19. assert loop.lr == 0.001
  20. assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
  21. assert loop.batch_size == 100
  22. def test_train_loop_02():
  23. d = Data()
  24. d.add_node_type('Dummy', 10)
  25. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  26. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  27. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  28. m = Model(prep_d)
  29. for prm in m.parameters():
  30. print(prm.shape, prm.is_leaf, prm.requires_grad)
  31. loop = TrainLoop(m)
  32. loop.run_epoch()
  33. for prm in m.parameters():
  34. print(prm.shape, prm.is_leaf, prm.requires_grad)
  35. def test_train_loop_03():
  36. # pdb.set_trace()
  37. if torch.cuda.device_count() == 0:
  38. pytest.skip('CUDA required for this test')
  39. adj_mat = torch.rand(10, 10).round()
  40. dev = torch.device('cuda:0')
  41. adj_mat = adj_mat.to(dev)
  42. d = Data()
  43. d.add_node_type('Dummy', 10)
  44. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  45. fam.add_relation_type('Dummy Rel', adj_mat)
  46. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  47. # pdb.set_trace()
  48. m = Model(prep_d)
  49. m = m.to(dev)
  50. print(list(m.parameters()))
  51. for prm in m.parameters():
  52. assert prm.device == dev
  53. loop = TrainLoop(m)
  54. loop.run_epoch()
  55. def test_timing_01():
  56. adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
  57. rep = torch.eye(2000).requires_grad_(True)
  58. t = time.time()
  59. for _ in range(1300):
  60. _ = torch.sparse.mm(adj_mat, rep)
  61. print('Elapsed:', time.time() - t)
  62. def test_timing_02():
  63. adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32)
  64. adj_mat_batch = [adj_mat.view(1, 2000, 2000)] * 1300
  65. adj_mat_batch = torch.cat(adj_mat_batch)
  66. rep = torch.eye(2000).requires_grad_(True)
  67. t = time.time()
  68. res = torch.matmul(adj_mat_batch, rep)
  69. print('Elapsed:', time.time() - t)
  70. print('res.shape:', res.shape)
  71. def test_timing_03():
  72. adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32)
  73. adj_mat_batch = [adj_mat.view(1, 2000, 2000).to_sparse()] * 1300
  74. adj_mat_batch = torch.cat(adj_mat_batch)
  75. rep = torch.eye(2000).requires_grad_(True)
  76. rep_batch = [rep.view(1, 2000, 2000)] * 1300
  77. rep_batch = torch.cat(rep_batch)
  78. t = time.time()
  79. with pytest.raises(RuntimeError):
  80. _ = torch.bmm(adj_mat_batch, rep)
  81. print('Elapsed:', time.time() - t)
  82. def test_timing_04():
  83. adj_mat = (torch.rand(2000, 2000) < .0001).to(torch.float32).to_sparse()
  84. rep = torch.eye(2000).requires_grad_(True)
  85. t = time.time()
  86. for _ in range(1300):
  87. _ = torch.sparse.mm(adj_mat, rep)
  88. print('Elapsed:', time.time() - t)
  89. def test_timing_05():
  90. if torch.cuda.device_count() == 0:
  91. pytest.skip('Test requires CUDA')
  92. dev = torch.device('cuda:0')
  93. adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse().to(dev)
  94. rep = torch.eye(2000).requires_grad_(True).to(dev)
  95. t = time.time()
  96. for _ in range(1300):
  97. _ = torch.sparse.mm(adj_mat, rep)
  98. torch.cuda.synchronize()
  99. print('Elapsed:', time.time() - t)