|  |  | @@ -15,3 +15,22 @@ def test_sparse_diag_cat_01(): | 
		
	
		
			
			|  |  |  | res = _sparse_diag_cat([ m.to_sparse() for m in matrices ]) | 
		
	
		
			
			|  |  |  | res = res.to_dense() | 
		
	
		
			
			|  |  |  | assert torch.all(res == ground_truth) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_sparse_diag_cat_02(): | 
		
	
		
			
			|  |  |  | x = [ torch.rand(5, 10).round() for _ in range(7) ] | 
		
	
		
			
			|  |  |  | a = [ m.to_sparse() for m in x ] | 
		
	
		
			
			|  |  |  | a = _sparse_diag_cat(a) | 
		
	
		
			
			|  |  |  | b = torch.rand(70, 64) | 
		
	
		
			
			|  |  |  | res = torch.sparse.mm(a, b) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | ground_truth = torch.zeros(35, 64) | 
		
	
		
			
			|  |  |  | ground_truth[0:5, :] = torch.mm(x[0], b[0:10]) | 
		
	
		
			
			|  |  |  | ground_truth[5:10, :] = torch.mm(x[1], b[10:20]) | 
		
	
		
			
			|  |  |  | ground_truth[10:15, :] = torch.mm(x[2], b[20:30]) | 
		
	
		
			
			|  |  |  | ground_truth[15:20, :] = torch.mm(x[3], b[30:40]) | 
		
	
		
			
			|  |  |  | ground_truth[20:25, :] = torch.mm(x[4], b[40:50]) | 
		
	
		
			
			|  |  |  | ground_truth[25:30, :] = torch.mm(x[5], b[50:60]) | 
		
	
		
			
			|  |  |  | ground_truth[30:35, :] = torch.mm(x[6], b[60:70]) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert torch.all(res == ground_truth) |