From aa541e0b42d9f27300336d60171c083f91d3cf4b Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Wed, 17 Jun 2020 12:43:04 +0200 Subject: [PATCH] Fix test_convolve. --- tests/icosagon/test_convolve.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/icosagon/test_convolve.py b/tests/icosagon/test_convolve.py index a916a89..5802800 100644 --- a/tests/icosagon/test_convolve.py +++ b/tests/icosagon/test_convolve.py @@ -13,7 +13,7 @@ def _test_graph_conv_01(use_sparse: bool): graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \ if use_sparse else adj_mat) - graph_conv.weight = torch.eye(20) + graph_conv.weight = torch.nn.Parameter(torch.eye(20)) res = graph_conv(node_reprs) assert torch.all(res == adj_mat) @@ -28,7 +28,7 @@ def _test_graph_conv_02(use_sparse: bool): graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \ if use_sparse else adj_mat) - graph_conv.weight = torch.eye(20) * 2 + graph_conv.weight = torch.nn.Parameter(torch.eye(20) * 2) res = graph_conv(node_reprs) assert torch.all(res == adj_mat * 2) @@ -57,14 +57,14 @@ def _test_graph_conv_03(use_sparse: bool): graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \ if use_sparse else adj_mat) - graph_conv.weight = torch.tensor([ + graph_conv.weight = torch.nn.Parameter(torch.tensor([ [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1] - ], dtype=torch.float32) + ], dtype=torch.float32)) res = graph_conv(node_reprs) assert torch.all(res == expect)