diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py index 75f056e..454d3c1 100644 --- a/tests/triacontagon/test_sampling.py +++ b/tests/triacontagon/test_sampling.py @@ -20,16 +20,19 @@ def test_get_true_classes_01(): print('true_classes:', true_classes) assert torch.all(true_classes == torch.tensor([ + [1, 3], [1, 3], [4, -1], [0, 1], + [0, 1], + [2, 4], [2, 4], [1, -1] ])) def test_get_true_classes_02(): - adj_mat = torch.rand(2000, 2000).round().to_sparse() + adj_mat = (torch.rand(2000, 2000) < 0.1).to_sparse() t = time.time() true_classes = get_true_classes(adj_mat)