IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

191 行
5.1KB

  1. from icosagon.convolve import GraphConv, \
  2. DropoutGraphConvActivation
  3. import torch
  4. from icosagon.dropout import dropout
  5. def _test_graph_conv_01(use_sparse: bool):
  6. adj_mat = torch.rand((10, 20))
  7. adj_mat[adj_mat < .5] = 0
  8. adj_mat = torch.ceil(adj_mat)
  9. node_reprs = torch.eye(20)
  10. graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
  11. if use_sparse else adj_mat)
  12. graph_conv.weight = torch.eye(20)
  13. res = graph_conv(node_reprs)
  14. assert torch.all(res == adj_mat)
  15. def _test_graph_conv_02(use_sparse: bool):
  16. adj_mat = torch.rand((10, 20))
  17. adj_mat[adj_mat < .5] = 0
  18. adj_mat = torch.ceil(adj_mat)
  19. node_reprs = torch.eye(20)
  20. graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
  21. if use_sparse else adj_mat)
  22. graph_conv.weight = torch.eye(20) * 2
  23. res = graph_conv(node_reprs)
  24. assert torch.all(res == adj_mat * 2)
  25. def _test_graph_conv_03(use_sparse: bool):
  26. adj_mat = torch.tensor([
  27. [1, 0, 1, 0, 1, 0], # [1, 0, 0]
  28. [1, 0, 1, 0, 0, 1], # [1, 0, 0]
  29. [1, 1, 0, 1, 0, 0], # [0, 1, 0]
  30. [0, 0, 0, 1, 0, 1], # [0, 1, 0]
  31. [1, 1, 1, 1, 1, 1], # [0, 0, 1]
  32. [0, 0, 0, 1, 1, 1] # [0, 0, 1]
  33. ], dtype=torch.float32)
  34. expect = torch.tensor([
  35. [1, 1, 1],
  36. [1, 1, 1],
  37. [2, 1, 0],
  38. [0, 1, 1],
  39. [2, 2, 2],
  40. [0, 1, 2]
  41. ], dtype=torch.float32)
  42. node_reprs = torch.eye(6)
  43. graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
  44. if use_sparse else adj_mat)
  45. graph_conv.weight = torch.tensor([
  46. [1, 0, 0],
  47. [1, 0, 0],
  48. [0, 1, 0],
  49. [0, 1, 0],
  50. [0, 0, 1],
  51. [0, 0, 1]
  52. ], dtype=torch.float32)
  53. res = graph_conv(node_reprs)
  54. assert torch.all(res == expect)
  55. def test_graph_conv_dense_01():
  56. _test_graph_conv_01(use_sparse=False)
  57. def test_graph_conv_dense_02():
  58. _test_graph_conv_02(use_sparse=False)
  59. def test_graph_conv_dense_03():
  60. _test_graph_conv_03(use_sparse=False)
  61. def test_graph_conv_sparse_01():
  62. _test_graph_conv_01(use_sparse=True)
  63. def test_graph_conv_sparse_02():
  64. _test_graph_conv_02(use_sparse=True)
  65. def test_graph_conv_sparse_03():
  66. _test_graph_conv_03(use_sparse=True)
  67. def _test_dropout_graph_conv_activation_01(use_sparse: bool):
  68. adj_mat = torch.rand((10, 20))
  69. adj_mat[adj_mat < .5] = 0
  70. adj_mat = torch.ceil(adj_mat)
  71. node_reprs = torch.eye(20)
  72. conv_1 = DropoutGraphConvActivation(20, 20, adj_mat.to_sparse() \
  73. if use_sparse else adj_mat, keep_prob=1.,
  74. activation=lambda x: x)
  75. conv_2 = GraphConv(20, 20, adj_mat.to_sparse() \
  76. if use_sparse else adj_mat)
  77. conv_2.weight = conv_1.graph_conv.weight
  78. res_1 = conv_1(node_reprs)
  79. res_2 = conv_2(node_reprs)
  80. print('res_1:', res_1.detach().cpu().numpy())
  81. print('res_2:', res_2.detach().cpu().numpy())
  82. assert torch.all(res_1 == res_2)
  83. def _test_dropout_graph_conv_activation_02(use_sparse: bool):
  84. adj_mat = torch.rand((10, 20))
  85. adj_mat[adj_mat < .5] = 0
  86. adj_mat = torch.ceil(adj_mat)
  87. node_reprs = torch.eye(20)
  88. conv_1 = DropoutGraphConvActivation(20, 20, adj_mat.to_sparse() \
  89. if use_sparse else adj_mat, keep_prob=1.,
  90. activation=lambda x: x * 2)
  91. conv_2 = GraphConv(20, 20, adj_mat.to_sparse() \
  92. if use_sparse else adj_mat)
  93. conv_2.weight = conv_1.graph_conv.weight
  94. res_1 = conv_1(node_reprs)
  95. res_2 = conv_2(node_reprs)
  96. print('res_1:', res_1.detach().cpu().numpy())
  97. print('res_2:', res_2.detach().cpu().numpy())
  98. assert torch.all(res_1 == res_2 * 2)
  99. def _test_dropout_graph_conv_activation_03(use_sparse: bool):
  100. adj_mat = torch.rand((10, 20))
  101. adj_mat[adj_mat < .5] = 0
  102. adj_mat = torch.ceil(adj_mat)
  103. node_reprs = torch.eye(20)
  104. conv_1 = DropoutGraphConvActivation(20, 20, adj_mat.to_sparse() \
  105. if use_sparse else adj_mat, keep_prob=.5,
  106. activation=lambda x: x)
  107. conv_2 = GraphConv(20, 20, adj_mat.to_sparse() \
  108. if use_sparse else adj_mat)
  109. conv_2.weight = conv_1.graph_conv.weight
  110. torch.random.manual_seed(0)
  111. res_1 = conv_1(node_reprs)
  112. torch.random.manual_seed(0)
  113. res_2 = conv_2(dropout(node_reprs, 0.5))
  114. print('res_1:', res_1.detach().cpu().numpy())
  115. print('res_2:', res_2.detach().cpu().numpy())
  116. assert torch.all(res_1 == res_2)
  117. def test_dropout_graph_conv_activation_dense_01():
  118. _test_dropout_graph_conv_activation_01(False)
  119. def test_dropout_graph_conv_activation_sparse_01():
  120. _test_dropout_graph_conv_activation_01(True)
  121. def test_dropout_graph_conv_activation_dense_02():
  122. _test_dropout_graph_conv_activation_02(False)
  123. def test_dropout_graph_conv_activation_sparse_02():
  124. _test_dropout_graph_conv_activation_02(True)
  125. def test_dropout_graph_conv_activation_dense_03():
  126. _test_dropout_graph_conv_activation_03(False)
  127. def test_dropout_graph_conv_activation_sparse_03():
  128. _test_dropout_graph_conv_activation_03(True)