| 
				
				
					
				
				
				 | 
			
			 | 
			@@ -112,3 +112,16 @@ def test_timing_04(): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    for _ in range(1300):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        _ = torch.sparse.mm(adj_mat, rep)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    print('Elapsed:', time.time() - t)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def test_timing_05():
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    if torch.cuda.device_count() == 0:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        pytest.skip('Test requires CUDA')
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    dev = torch.device('cuda:0')
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse().to(dev)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    rep = torch.eye(2000).requires_grad_(True).to(dev)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    t = time.time()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    for _ in range(1300):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        _ = torch.sparse.mm(adj_mat, rep)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    torch.cuda.synchronize()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    print('Elapsed:', time.time() - t)
 |