import torch
import torch_stablesort_cpp


class StableSort(torch.autograd.Function):
    """Autograd wrapper around the ``torch_stablesort_cpp.stable_sort`` extension.

    ``forward`` delegates the actual sort to the C++ extension and returns
    ``(values, indices)``. ``backward`` routes the gradient of ``values`` back
    to the input positions each value was taken from; ``indices`` is an
    integer permutation and receives no gradient.
    """

    @staticmethod
    def forward(ctx, input, dim=-1, descending=False, out=None):
        """Stable-sort ``input`` along ``dim`` via the C++ extension.

        Args:
            input: Tensor to sort.
            dim: Dimension to sort along (negative values count from the end).
            descending: Passed straight through to the extension.
            out: Optional output argument, passed straight through to the
                extension (its exact expected form is defined by the extension
                — NOTE(review): confirm).

        Returns:
            ``(values, indices)`` as produced by the extension.
        """
        values, indices = torch_stablesort_cpp.stable_sort(input, dim, descending, out)
        # Only the permutation and the sort dimension are needed to invert the
        # sort in backward; the input values themselves are not.
        ctx.save_for_backward(indices)
        ctx.dim = dim
        # The integer index output has no meaningful gradient.
        ctx.mark_non_differentiable(indices)
        return values, indices

    @staticmethod
    def backward(ctx, grad_values, grad_indices):
        """Scatter ``grad_values`` back to the pre-sort positions.

        ``values`` was gathered from ``input`` along ``dim`` using ``indices``,
        so the gradient w.r.t. ``input`` is the inverse operation: scatter
        ``grad_values`` into those same positions along ``dim``.
        ``grad_indices`` is ignored because ``indices`` is non-differentiable.

        Returns:
            One entry per ``forward`` input: the gradient for ``input`` and
            ``None`` for ``dim``, ``descending`` and ``out``.
        """
        (indices,) = ctx.saved_tensors
        grad_input = torch.zeros_like(grad_values).scatter_(ctx.dim, indices, grad_values)
        return grad_input, None, None, None