|
|
@@ -0,0 +1,18 @@ |
|
|
|
import torch
|
|
|
|
import torch_stablesort_cpp
|
|
|
|
|
|
|
|
|
|
|
|
class StableSort(torch.autograd.Function):
|
|
|
|
@staticmethod
|
|
|
|
def forward(ctx, input, dim=-1, descending=False, out=None):
|
|
|
|
values, indices = \
|
|
|
|
torch_stablesort_cpp.stable_sort(input, dim, descending, out)
|
|
|
|
ctx.save_for_backward(input, indices)
|
|
|
|
return values, indices
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def backward(ctx, grad_values, grad_indices):
|
|
|
|
input, indices = ctx.saved_variables
|
|
|
|
res = torch.empty_like(grad_values)
|
|
|
|
res[indices] = grad_values + grad_indices
|
|
|
|
return res
|