IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Pārlūkot izejas kodu

Fix test_convolve.

master
Stanislaw Adaszewski pirms 4 gadiem
vecāks
revīzija
aa541e0b42
1 mainītis faili ar 4 papildinājumiem un 4 dzēšanām
  1. +4
    -4
      tests/icosagon/test_convolve.py

+ 4
- 4
tests/icosagon/test_convolve.py Parādīt failu

@@ -13,7 +13,7 @@ def _test_graph_conv_01(use_sparse: bool):
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
if use_sparse else adj_mat)
graph_conv.weight = torch.eye(20)
graph_conv.weight = torch.nn.Parameter(torch.eye(20))
res = graph_conv(node_reprs)
assert torch.all(res == adj_mat)
@@ -28,7 +28,7 @@ def _test_graph_conv_02(use_sparse: bool):
graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
if use_sparse else adj_mat)
graph_conv.weight = torch.eye(20) * 2
graph_conv.weight = torch.nn.Parameter(torch.eye(20) * 2)
res = graph_conv(node_reprs)
assert torch.all(res == adj_mat * 2)
@@ -57,14 +57,14 @@ def _test_graph_conv_03(use_sparse: bool):
graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
if use_sparse else adj_mat)
graph_conv.weight = torch.tensor([
graph_conv.weight = torch.nn.Parameter(torch.tensor([
[1, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[0, 0, 1]
], dtype=torch.float32)
], dtype=torch.float32))
res = graph_conv(node_reprs)
assert torch.all(res == expect)


Notiek ielāde…
Atcelt
Saglabāt