| 
				
				
				
				 | 
			
			 | 
			@@ -0,0 +1,18 @@ | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			import torch
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			import torch_stablesort_cpp
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			class StableSort(torch.autograd.Function):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    @staticmethod
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    def forward(ctx, input, dim=-1, descending=False, out=None):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        values, indices = \
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            torch_stablesort_cpp.stable_sort(input, dim, descending, out)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        ctx.save_for_backward(input, indices)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        return values, indices
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    @staticmethod
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    def backward(ctx, grad_values, grad_indices):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        input, indices = ctx.saved_variables
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        res = torch.empty_like(grad_values)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        res[indices] = grad_values + grad_indices
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        return res
 |