IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

95 lines
2.3KB

  1. from icosagon.convolve import GraphConv, \
  2. DropoutGraphConvActivation, \
  3. MultiDGCA
  4. import torch
  5. def _test_graph_conv_01(use_sparse: bool):
  6. adj_mat = torch.rand((10, 20))
  7. adj_mat[adj_mat < .5] = 0
  8. adj_mat = torch.ceil(adj_mat)
  9. node_reprs = torch.eye(20)
  10. graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
  11. if use_sparse else adj_mat)
  12. graph_conv.weight = torch.eye(20)
  13. res = graph_conv(node_reprs)
  14. assert torch.all(res == adj_mat)
  15. def _test_graph_conv_02(use_sparse: bool):
  16. adj_mat = torch.rand((10, 20))
  17. adj_mat[adj_mat < .5] = 0
  18. adj_mat = torch.ceil(adj_mat)
  19. node_reprs = torch.eye(20)
  20. graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
  21. if use_sparse else adj_mat)
  22. graph_conv.weight = torch.eye(20) * 2
  23. res = graph_conv(node_reprs)
  24. assert torch.all(res == adj_mat * 2)
  25. def _test_graph_conv_03(use_sparse: bool):
  26. adj_mat = torch.tensor([
  27. [1, 0, 1, 0, 1, 0], # [1, 0, 0]
  28. [1, 0, 1, 0, 0, 1], # [1, 0, 0]
  29. [1, 1, 0, 1, 0, 0], # [0, 1, 0]
  30. [0, 0, 0, 1, 0, 1], # [0, 1, 0]
  31. [1, 1, 1, 1, 1, 1], # [0, 0, 1]
  32. [0, 0, 0, 1, 1, 1] # [0, 0, 1]
  33. ], dtype=torch.float32)
  34. expect = torch.tensor([
  35. [1, 1, 1],
  36. [1, 1, 1],
  37. [2, 1, 0],
  38. [0, 1, 1],
  39. [2, 2, 2],
  40. [0, 1, 2]
  41. ], dtype=torch.float32)
  42. node_reprs = torch.eye(6)
  43. graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
  44. if use_sparse else adj_mat)
  45. graph_conv.weight = torch.tensor([
  46. [1, 0, 0],
  47. [1, 0, 0],
  48. [0, 1, 0],
  49. [0, 1, 0],
  50. [0, 0, 1],
  51. [0, 0, 1]
  52. ], dtype=torch.float32)
  53. res = graph_conv(node_reprs)
  54. assert torch.all(res == expect)
  55. def test_graph_conv_dense_01():
  56. _test_graph_conv_01(use_sparse=False)
  57. def test_graph_conv_dense_02():
  58. _test_graph_conv_02(use_sparse=False)
  59. def test_graph_conv_dense_03():
  60. _test_graph_conv_03(use_sparse=False)
  61. def test_graph_conv_sparse_01():
  62. _test_graph_conv_01(use_sparse=True)
  63. def test_graph_conv_sparse_02():
  64. _test_graph_conv_02(use_sparse=True)
  65. def test_graph_conv_sparse_03():
  66. _test_graph_conv_03(use_sparse=True)