|
@@ -13,7 +13,7 @@ def _test_graph_conv_01(use_sparse: bool): |
|
|
|
|
|
|
|
|
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
|
|
|
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
|
|
|
if use_sparse else adj_mat)
|
|
|
if use_sparse else adj_mat)
|
|
|
graph_conv.weight = torch.eye(20)
|
|
|
|
|
|
|
|
|
graph_conv.weight = torch.nn.Parameter(torch.eye(20))
|
|
|
|
|
|
|
|
|
res = graph_conv(node_reprs)
|
|
|
res = graph_conv(node_reprs)
|
|
|
assert torch.all(res == adj_mat)
|
|
|
assert torch.all(res == adj_mat)
|
|
@@ -28,7 +28,7 @@ def _test_graph_conv_02(use_sparse: bool): |
|
|
|
|
|
|
|
|
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
|
|
|
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
|
|
|
if use_sparse else adj_mat)
|
|
|
if use_sparse else adj_mat)
|
|
|
graph_conv.weight = torch.eye(20) * 2
|
|
|
|
|
|
|
|
|
graph_conv.weight = torch.nn.Parameter(torch.eye(20) * 2)
|
|
|
|
|
|
|
|
|
res = graph_conv(node_reprs)
|
|
|
res = graph_conv(node_reprs)
|
|
|
assert torch.all(res == adj_mat * 2)
|
|
|
assert torch.all(res == adj_mat * 2)
|
|
@@ -57,14 +57,14 @@ def _test_graph_conv_03(use_sparse: bool): |
|
|
|
|
|
|
|
|
graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
|
|
|
graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
|
|
|
if use_sparse else adj_mat)
|
|
|
if use_sparse else adj_mat)
|
|
|
graph_conv.weight = torch.tensor([
|
|
|
|
|
|
|
|
|
graph_conv.weight = torch.nn.Parameter(torch.tensor([
|
|
|
[1, 0, 0],
|
|
|
[1, 0, 0],
|
|
|
[1, 0, 0],
|
|
|
[1, 0, 0],
|
|
|
[0, 1, 0],
|
|
|
[0, 1, 0],
|
|
|
[0, 1, 0],
|
|
|
[0, 1, 0],
|
|
|
[0, 0, 1],
|
|
|
[0, 0, 1],
|
|
|
[0, 0, 1]
|
|
|
[0, 0, 1]
|
|
|
], dtype=torch.float32)
|
|
|
|
|
|
|
|
|
], dtype=torch.float32))
|
|
|
|
|
|
|
|
|
res = graph_conv(node_reprs)
|
|
|
res = graph_conv(node_reprs)
|
|
|
assert torch.all(res == expect)
|
|
|
assert torch.all(res == expect)
|
|
|