IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

185 lines
5.0KB

  1. from icosagon.data import Data, \
  2. _equal
  3. from icosagon.trainprep import prepare_training, \
  4. TrainValTest
  5. from icosagon.model import Model
  6. from icosagon.trainloop import TrainLoop
  7. import torch
  8. import pytest
  9. import pdb
  10. import time
  11. def test_train_loop_01():
  12. d = Data()
  13. d.add_node_type('Dummy', 10)
  14. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  15. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  16. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  17. m = Model(prep_d)
  18. loop = TrainLoop(m)
  19. assert loop.model == m
  20. assert loop.lr == 0.001
  21. assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
  22. assert loop.batch_size == 100
  23. def test_train_loop_02():
  24. d = Data()
  25. d.add_node_type('Dummy', 10)
  26. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  27. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  28. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  29. m = Model(prep_d)
  30. for prm in m.parameters():
  31. print(prm.shape, prm.is_leaf, prm.requires_grad)
  32. loop = TrainLoop(m)
  33. loop.run_epoch()
  34. for prm in m.parameters():
  35. print(prm.shape, prm.is_leaf, prm.requires_grad)
  36. def test_train_loop_03():
  37. # pdb.set_trace()
  38. if torch.cuda.device_count() == 0:
  39. pytest.skip('CUDA required for this test')
  40. adj_mat = torch.rand(10, 10).round()
  41. dev = torch.device('cuda:0')
  42. adj_mat = adj_mat.to(dev)
  43. d = Data()
  44. d.add_node_type('Dummy', 10)
  45. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  46. fam.add_relation_type('Dummy Rel', adj_mat)
  47. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  48. # pdb.set_trace()
  49. m = Model(prep_d)
  50. m = m.to(dev)
  51. print(list(m.parameters()))
  52. for prm in m.parameters():
  53. assert prm.device == dev
  54. loop = TrainLoop(m)
  55. loop.run_epoch()
  56. def test_train_loop_04():
  57. adj_mat = torch.rand(10, 10).round()
  58. d = Data()
  59. d.add_node_type('Dummy', 10)
  60. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  61. fam.add_relation_type('Dummy Rel', adj_mat)
  62. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  63. m = Model(prep_d)
  64. old_values = []
  65. for prm in m.parameters():
  66. old_values.append(prm.clone().detach())
  67. loop = TrainLoop(m)
  68. loop.run_epoch()
  69. for i, prm in enumerate(m.parameters()):
  70. assert not prm.requires_grad or \
  71. not torch.all(_equal(prm, old_values[i]))
  72. def test_train_loop_05():
  73. adj_mat = torch.rand(10, 10).round().to_sparse()
  74. d = Data()
  75. d.add_node_type('Dummy', 10)
  76. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  77. fam.add_relation_type('Dummy Rel', adj_mat)
  78. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  79. m = Model(prep_d)
  80. old_values = []
  81. for prm in m.parameters():
  82. old_values.append(prm.clone().detach())
  83. loop = TrainLoop(m)
  84. loop.run_epoch()
  85. for i, prm in enumerate(m.parameters()):
  86. assert not prm.requires_grad or \
  87. not torch.all(_equal(prm, old_values[i]))
  88. def test_timing_01():
  89. adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
  90. rep = torch.eye(2000).requires_grad_(True)
  91. t = time.time()
  92. for _ in range(1300):
  93. _ = torch.sparse.mm(adj_mat, rep)
  94. print('Elapsed:', time.time() - t)
  95. def test_timing_02():
  96. adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32)
  97. adj_mat_batch = [adj_mat.view(1, 2000, 2000)] * 1300
  98. adj_mat_batch = torch.cat(adj_mat_batch)
  99. rep = torch.eye(2000).requires_grad_(True)
  100. t = time.time()
  101. res = torch.matmul(adj_mat_batch, rep)
  102. print('Elapsed:', time.time() - t)
  103. print('res.shape:', res.shape)
  104. def test_timing_03():
  105. adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32)
  106. adj_mat_batch = [adj_mat.view(1, 2000, 2000).to_sparse()] * 1300
  107. adj_mat_batch = torch.cat(adj_mat_batch)
  108. rep = torch.eye(2000).requires_grad_(True)
  109. rep_batch = [rep.view(1, 2000, 2000)] * 1300
  110. rep_batch = torch.cat(rep_batch)
  111. t = time.time()
  112. with pytest.raises(RuntimeError):
  113. _ = torch.bmm(adj_mat_batch, rep)
  114. print('Elapsed:', time.time() - t)
  115. def test_timing_04():
  116. adj_mat = (torch.rand(2000, 2000) < .0001).to(torch.float32).to_sparse()
  117. rep = torch.eye(2000).requires_grad_(True)
  118. t = time.time()
  119. for _ in range(1300):
  120. _ = torch.sparse.mm(adj_mat, rep)
  121. print('Elapsed:', time.time() - t)
  122. def test_timing_05():
  123. if torch.cuda.device_count() == 0:
  124. pytest.skip('Test requires CUDA')
  125. dev = torch.device('cuda:0')
  126. adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse().to(dev)
  127. rep = torch.eye(2000).requires_grad_(True).to(dev)
  128. t = time.time()
  129. for _ in range(1300):
  130. _ = torch.sparse.mm(adj_mat, rep)
  131. torch.cuda.synchronize()
  132. print('Elapsed:', time.time() - t)