@@ -13,7 +13,7 @@ def _test_graph_conv_01(use_sparse: bool):
     graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
                            if use_sparse else adj_mat)
-    graph_conv.weight = torch.eye(20)
+    graph_conv.weight = torch.nn.Parameter(torch.eye(20))
     res = graph_conv(node_reprs)
     assert torch.all(res == adj_mat)
@@ -28,7 +28,7 @@ def _test_graph_conv_02(use_sparse: bool):
     graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
                            if use_sparse else adj_mat)
-    graph_conv.weight = torch.eye(20) * 2
+    graph_conv.weight = torch.nn.Parameter(torch.eye(20) * 2)
     res = graph_conv(node_reprs)
     assert torch.all(res == adj_mat * 2)
@@ -57,14 +57,14 @@ def _test_graph_conv_03(use_sparse: bool):
     graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
                            if use_sparse else adj_mat)
-    graph_conv.weight = torch.tensor([
+    graph_conv.weight = torch.nn.Parameter(torch.tensor([
         [1, 0, 0],
         [1, 0, 0],
         [0, 1, 0],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1]
-    ], dtype=torch.float32)
+    ], dtype=torch.float32))
     res = graph_conv(node_reprs)
     assert torch.all(res == expect)
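
The change above reflects how torch.nn.Module handles attributes registered as parameters: once an attribute like weight is an nn.Parameter, assigning a plain tensor to it raises a TypeError, so the test overrides must be wrapped in torch.nn.Parameter. A minimal sketch illustrating this behavior follows; TinyGraphConv here is a hypothetical stand-in for illustration only, not the repo's GraphConv implementation.

import torch
import torch.nn as nn

class TinyGraphConv(nn.Module):
    """Toy module used only to demonstrate parameter assignment rules."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # weight is registered as a parameter on the module.
        self.weight = nn.Parameter(torch.empty(in_features, out_features))

module = TinyGraphConv(20, 20)

# Assigning a plain tensor to a registered parameter fails with a TypeError
# along the lines of "cannot assign 'torch.FloatTensor' as parameter 'weight'":
# module.weight = torch.eye(20)

# Wrapping the tensor in nn.Parameter keeps the attribute registered (and
# visible to module.parameters()), which is what the updated tests do.
module.weight = nn.Parameter(torch.eye(20))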