"""Autograd wrapper around the ``torch_stablesort_cpp`` stable-sort extension."""
import torch
import torch_stablesort_cpp


class StableSort(torch.autograd.Function):
    """Differentiable stable sort backed by ``torch_stablesort_cpp.stable_sort``.

    Call as ``StableSort.apply(input, dim, descending, out)``. Returns the
    sorted ``values`` and the ``indices`` the extension produced; only
    ``values`` propagates a gradient back to ``input``.
    """

    @staticmethod
    def forward(ctx, input, dim=-1, descending=False, out=None):
        """Sort ``input`` along ``dim`` using the C++ stable-sort extension.

        Args:
            ctx: Autograd context used to stash state for :meth:`backward`.
            input: Tensor to sort.
            dim: Dimension to sort along (default ``-1``).
            descending: Sort in descending order when ``True``.
            out: Optional output argument, forwarded unchanged to the
                extension.

        Returns:
            ``(values, indices)`` as returned by the extension.
        """
        values, indices = torch_stablesort_cpp.stable_sort(
            input, dim, descending, out
        )
        # Only the permutation and the (plain int) dim are needed to route
        # gradients back; saving ``input`` would just keep it alive longlessly.
        ctx.save_for_backward(indices)
        ctx.dim = dim
        # Integer indices carry no gradient; tell autograd explicitly.
        ctx.mark_non_differentiable(indices)
        return values, indices

    @staticmethod
    def backward(ctx, grad_values, grad_indices):
        """Send each value's gradient back to the position it was taken from.

        ``values`` equals ``input.gather(dim, indices)``, so the gradient with
        respect to ``input`` is ``grad_values`` *scattered* through
        ``indices`` along ``dim`` (i.e. the inverse permutation). Gathering
        would only be correct for self-inverse permutations.

        Args:
            ctx: Autograd context populated by :meth:`forward`.
            grad_values: Gradient of the loss w.r.t. the sorted values.
            grad_indices: Unused; the indices are non-differentiable.

        Returns:
            One entry per ``forward`` input: the gradient for ``input``,
            then ``None`` for ``dim``, ``descending`` and ``out``.
        """
        (indices,) = ctx.saved_tensors
        # The sorted values have the same shape as ``input``, so a zero tensor
        # shaped like ``grad_values`` is the right-sized gradient buffer.
        # ``indices`` is a permutation along ``dim``, so plain scatter (no
        # accumulation) writes every slot exactly once.
        grad_input = torch.zeros_like(grad_values).scatter_(
            ctx.dim, indices, grad_values
        )
        return grad_input, None, None, None