| @@ -20,16 +20,19 @@ def test_get_true_classes_01(): | |||||
| print('true_classes:', true_classes) | print('true_classes:', true_classes) | ||||
| assert torch.all(true_classes == torch.tensor([ | assert torch.all(true_classes == torch.tensor([ | ||||
| [1, 3], | |||||
| [1, 3], | [1, 3], | ||||
| [4, -1], | [4, -1], | ||||
| [0, 1], | [0, 1], | ||||
| [0, 1], | |||||
| [2, 4], | |||||
| [2, 4], | [2, 4], | ||||
| [1, -1] | [1, -1] | ||||
| ])) | ])) | ||||
| def test_get_true_classes_02(): | def test_get_true_classes_02(): | ||||
| adj_mat = torch.rand(2000, 2000).round().to_sparse() | |||||
| adj_mat = (torch.rand(2000, 2000) < 0.1).to_sparse() | |||||
| t = time.time() | t = time.time() | ||||
| true_classes = get_true_classes(adj_mat) | true_classes = get_true_classes(adj_mat) | ||||