|
@@ -20,16 +20,19 @@ def test_get_true_classes_01(): |
|
|
print('true_classes:', true_classes)
|
|
|
print('true_classes:', true_classes)
|
|
|
|
|
|
|
|
|
assert torch.all(true_classes == torch.tensor([
|
|
|
assert torch.all(true_classes == torch.tensor([
|
|
|
|
|
|
[1, 3],
|
|
|
[1, 3],
|
|
|
[1, 3],
|
|
|
[4, -1],
|
|
|
[4, -1],
|
|
|
[0, 1],
|
|
|
[0, 1],
|
|
|
|
|
|
[0, 1],
|
|
|
|
|
|
[2, 4],
|
|
|
[2, 4],
|
|
|
[2, 4],
|
|
|
[1, -1]
|
|
|
[1, -1]
|
|
|
]))
|
|
|
]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_true_classes_02():
|
|
|
def test_get_true_classes_02():
|
|
|
adj_mat = torch.rand(2000, 2000).round().to_sparse()
|
|
|
|
|
|
|
|
|
adj_mat = (torch.rand(2000, 2000) < 0.1).to_sparse()
|
|
|
|
|
|
|
|
|
t = time.time()
|
|
|
t = time.time()
|
|
|
true_classes = get_true_classes(adj_mat)
|
|
|
true_classes = get_true_classes(adj_mat)
|
|
|