IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_fastconv.py 2.8KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from icosagon.fastconv import _sparse_diag_cat, \
  2. _cat, \
  3. FastGraphConv
  4. from icosagon.data import _equal
  5. import torch
  6. import pdb
  7. import time
  8. def test_sparse_diag_cat_01():
  9. matrices = [ torch.rand(5, 10).round() for _ in range(7) ]
  10. ground_truth = torch.zeros(35, 70)
  11. ground_truth[0:5, 0:10] = matrices[0]
  12. ground_truth[5:10, 10:20] = matrices[1]
  13. ground_truth[10:15, 20:30] = matrices[2]
  14. ground_truth[15:20, 30:40] = matrices[3]
  15. ground_truth[20:25, 40:50] = matrices[4]
  16. ground_truth[25:30, 50:60] = matrices[5]
  17. ground_truth[30:35, 60:70] = matrices[6]
  18. res = _sparse_diag_cat([ m.to_sparse() for m in matrices ])
  19. res = res.to_dense()
  20. assert torch.all(res == ground_truth)
  21. def test_sparse_diag_cat_02():
  22. x = [ torch.rand(5, 10).round() for _ in range(7) ]
  23. a = [ m.to_sparse() for m in x ]
  24. a = _sparse_diag_cat(a)
  25. b = torch.rand(70, 64)
  26. res = torch.sparse.mm(a, b)
  27. ground_truth = torch.zeros(35, 64)
  28. ground_truth[0:5, :] = torch.mm(x[0], b[0:10])
  29. ground_truth[5:10, :] = torch.mm(x[1], b[10:20])
  30. ground_truth[10:15, :] = torch.mm(x[2], b[20:30])
  31. ground_truth[15:20, :] = torch.mm(x[3], b[30:40])
  32. ground_truth[20:25, :] = torch.mm(x[4], b[40:50])
  33. ground_truth[25:30, :] = torch.mm(x[5], b[50:60])
  34. ground_truth[30:35, :] = torch.mm(x[6], b[60:70])
  35. assert torch.all(res == ground_truth)
  36. def test_cat_01():
  37. matrices = [ torch.rand(5, 10) for _ in range(7) ]
  38. res = _cat(matrices)
  39. assert res.shape == (35, 10)
  40. assert not res.is_sparse
  41. ground_truth = torch.zeros(35, 10)
  42. for i in range(7):
  43. ground_truth[i*5:(i+1)*5, :] = matrices[i]
  44. assert torch.all(res == ground_truth)
  45. def test_cat_02():
  46. matrices = [ torch.rand(5, 10) for _ in range(7) ]
  47. ground_truth = torch.zeros(35, 10)
  48. for i in range(7):
  49. ground_truth[i*5:(i+1)*5, :] = matrices[i]
  50. res = _cat([ m.to_sparse() for m in matrices ])
  51. assert res.shape == (35, 10)
  52. assert res.is_sparse
  53. assert torch.all(res.to_dense() == ground_truth)
  54. def test_fast_graph_conv_01():
  55. # pdb.set_trace()
  56. adj_mats = [ torch.rand(10, 15).round().to_sparse() \
  57. for _ in range(23) ]
  58. fgc = FastGraphConv(32, 64, adj_mats)
  59. in_repr = torch.rand(15, 32)
  60. _ = fgc(in_repr)
  61. def test_fast_graph_conv_02():
  62. t = time.time()
  63. m = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
  64. adj_mats = [ m for _ in range(1300) ]
  65. print('Generating adj_mats took:', time.time() - t)
  66. t = time.time()
  67. fgc = FastGraphConv(32, 64, adj_mats)
  68. print('FGC constructor took:', time.time() - t)
  69. in_repr = torch.rand(2000, 32)
  70. for _ in range(3):
  71. t = time.time()
  72. _ = fgc(in_repr)
  73. print('FGC forward pass took:', time.time() - t)