IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Fix test_convolve.

master
Stanislaw Adaszewski 4 years ago
parent
commit
aa541e0b42
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      tests/icosagon/test_convolve.py

+ 4
- 4
tests/icosagon/test_convolve.py View File

@@ -13,7 +13,7 @@ def _test_graph_conv_01(use_sparse: bool):
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \ graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
if use_sparse else adj_mat) if use_sparse else adj_mat)
graph_conv.weight = torch.eye(20)
graph_conv.weight = torch.nn.Parameter(torch.eye(20))
res = graph_conv(node_reprs) res = graph_conv(node_reprs)
assert torch.all(res == adj_mat) assert torch.all(res == adj_mat)
@@ -28,7 +28,7 @@ def _test_graph_conv_02(use_sparse: bool):
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \ graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
if use_sparse else adj_mat) if use_sparse else adj_mat)
graph_conv.weight = torch.eye(20) * 2
graph_conv.weight = torch.nn.Parameter(torch.eye(20) * 2)
res = graph_conv(node_reprs) res = graph_conv(node_reprs)
assert torch.all(res == adj_mat * 2) assert torch.all(res == adj_mat * 2)
@@ -57,14 +57,14 @@ def _test_graph_conv_03(use_sparse: bool):
graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \ graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
if use_sparse else adj_mat) if use_sparse else adj_mat)
graph_conv.weight = torch.tensor([
graph_conv.weight = torch.nn.Parameter(torch.tensor([
[1, 0, 0], [1, 0, 0],
[1, 0, 0], [1, 0, 0],
[0, 1, 0], [0, 1, 0],
[0, 1, 0], [0, 1, 0],
[0, 0, 1], [0, 0, 1],
[0, 0, 1] [0, 0, 1]
], dtype=torch.float32)
], dtype=torch.float32))
res = graph_conv(node_reprs) res = graph_conv(node_reprs)
assert torch.all(res == expect) assert torch.all(res == expect)


Loading…
Cancel
Save