|  |  | @@ -78,6 +78,8 @@ def test_clear_adjacency_matrix_except_rows_03(): | 
		
	
		
			
			|  |  |  | def test_clear_adjacency_matrix_except_rows_04(): | 
		
	
		
			
			|  |  |  | adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | print('adj_mat.to_sparse():', adj_mat.to_sparse()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | t = time.time() | 
		
	
		
			
			|  |  |  | res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | 
		
	
		
			
			|  |  |  | print('_sparse_diag_cat() took:', time.time() - t) | 
		
	
	
		
			
				|  |  | @@ -93,3 +95,29 @@ def test_clear_adjacency_matrix_except_rows_04(): | 
		
	
		
			
			|  |  |  | truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert _equal(res, truth).all() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_clear_adjacency_matrix_except_rows_05(): | 
		
	
		
			
			|  |  |  | if torch.cuda.device_count() == 0: | 
		
	
		
			
			|  |  |  | pytest.skip('Test requires CUDA') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | device = torch.device('cuda:0') | 
		
	
		
			
			|  |  |  | adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8).to(device) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | print('adj_mat.to_sparse():', adj_mat.to_sparse()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | t = time.time() | 
		
	
		
			
			|  |  |  | res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | 
		
	
		
			
			|  |  |  | print('_sparse_diag_cat() took:', time.time() - t) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | rows = torch.tensor(list(range(512)), device=device) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | t = time.time() | 
		
	
		
			
			|  |  |  | res = _clear_adjacency_matrix_except_rows(res, rows, | 
		
	
		
			
			|  |  |  | 2000, 1300) | 
		
	
		
			
			|  |  |  | print('_clear_adjacency_matrix_except_rows() took:', time.time() - t) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | adj_mat[512:] = torch.zeros(2000) | 
		
	
		
			
			|  |  |  | truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert _equal(res, truth).all() |