|
- import torch
- import torch_stablesort_cpp
-
-
class StableSort(torch.autograd.Function):
    """Differentiable stable sort backed by ``torch_stablesort_cpp``.

    ``forward`` returns ``(values, indices)`` from the C++ extension.
    ``backward`` assumes ``indices`` follows ``torch.sort`` semantics: for each
    sorted position along ``dim``, it holds the position in ``input`` that the
    value came from. NOTE(review): confirm the extension guarantees this.
    """

    @staticmethod
    def forward(ctx, input, dim=-1, descending=False, out=None):
        """Sort ``input`` along ``dim`` with the C++ stable-sort kernel.

        Args:
            ctx: autograd context.
            input: tensor to sort.
            dim: dimension to sort along (defaults to the last one).
            descending: sort in descending order when True.
            out: optional output argument, passed through to the extension.

        Returns:
            ``(values, indices)`` as returned by
            ``torch_stablesort_cpp.stable_sort``.
        """
        values, indices = torch_stablesort_cpp.stable_sort(
            input, dim, descending, out
        )
        # Integer indices have no gradient; tell autograd so it doesn't
        # track them or expect a meaningful gradient for them.
        ctx.mark_non_differentiable(indices)
        # Only the permutation and the sort dimension are needed to route
        # gradients back, so don't keep ``input`` alive.
        ctx.save_for_backward(indices)
        ctx.dim = dim
        return values, indices

    @staticmethod
    def backward(ctx, grad_values, grad_indices):
        """Send each sorted value's gradient back to its source position.

        ``grad_indices`` is ignored because ``indices`` was marked
        non-differentiable in ``forward``.

        Returns:
            One entry per ``forward`` argument: the gradient w.r.t. ``input``,
            then ``None`` for ``dim``, ``descending`` and ``out``.
        """
        (indices,) = ctx.saved_tensors
        # Along ``dim``, sorted slot k came from input slot indices[..., k],
        # so scattering along ``dim`` inverts the permutation. Plain
        # ``res[indices]`` would wrongly index dim 0 for N-D inputs.
        grad_input = torch.zeros_like(grad_values).scatter_(
            ctx.dim, indices, grad_values
        )
        return grad_input, None, None, None
|