IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Fix test_convolve.

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
aa541e0b42
共有 1 个文件被更改,包括 4 次插入4 次删除
  1. +4
    -4
      tests/icosagon/test_convolve.py

+ 4
- 4
tests/icosagon/test_convolve.py 查看文件

@@ -13,7 +13,7 @@ def _test_graph_conv_01(use_sparse: bool):
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
if use_sparse else adj_mat)
graph_conv.weight = torch.eye(20)
graph_conv.weight = torch.nn.Parameter(torch.eye(20))
res = graph_conv(node_reprs)
assert torch.all(res == adj_mat)
@@ -28,7 +28,7 @@ def _test_graph_conv_02(use_sparse: bool):
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
if use_sparse else adj_mat)
graph_conv.weight = torch.eye(20) * 2
graph_conv.weight = torch.nn.Parameter(torch.eye(20) * 2)
res = graph_conv(node_reprs)
assert torch.all(res == adj_mat * 2)
@@ -57,14 +57,14 @@ def _test_graph_conv_03(use_sparse: bool):
graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
if use_sparse else adj_mat)
graph_conv.weight = torch.tensor([
graph_conv.weight = torch.nn.Parameter(torch.tensor([
[1, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[0, 0, 1]
], dtype=torch.float32)
], dtype=torch.float32))
res = graph_conv(node_reprs)
assert torch.all(res == expect)


正在加载...
取消
保存