|
12345678910111213141516171819202122232425262728 |
- import torch
- import torch_stablesort_cpp
-
-
class StableSort(torch.autograd.Function):
    """Differentiable stable sort backed by the ``torch_stablesort_cpp`` extension.

    Use via ``StableSort.apply(input, dim, descending, out)``; it returns
    ``(values, indices)``. Gradients flow back to ``input`` only; ``indices``
    is marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, input, dim=-1, descending=False, out=None):
        """Sort ``input`` along ``dim`` using the extension's stable sort.

        Args:
            input: Tensor to sort.
            dim: Dimension to sort along (default: last dimension).
            descending: Sort in descending order when True.
            out: Forwarded unchanged to ``torch_stablesort_cpp.stable_sort``.
                NOTE(review): presumably an optional output buffer -- confirm
                against the extension. If it is written in place and is a
                tensor that requires grad, ``ctx.mark_dirty`` would be needed.

        Returns:
            Tuple ``(values, indices)`` as produced by the extension.
        """
        values, indices = torch_stablesort_cpp.stable_sort(
            input, dim, descending, out
        )
        # The permutation alone is enough to route gradients back; keep the
        # plain-int dim on ctx rather than wrapping it in a tensor.
        ctx.save_for_backward(indices)
        ctx.dim = dim
        # Integer indices carry no gradient; this is the documented way to say so.
        ctx.mark_non_differentiable(indices)
        return values, indices

    @staticmethod
    def backward(ctx, grad_values, grad_indices):
        """Send ``grad_values`` back to the unsorted positions of ``input``.

        Assuming, as with ``torch.sort``, that
        ``values == input.gather(dim, indices)``, each sorted element came
        from ``input`` at position ``indices`` along ``dim``. The input
        gradient is therefore ``grad_values`` *scattered* through
        ``indices`` (the inverse permutation). Gathering with ``indices``
        would apply the permutation a second time and misplace gradients.

        Args:
            grad_values: Gradient with respect to the sorted ``values``.
            grad_indices: Gradient with respect to ``indices``; ignored because
                the indices are non-differentiable.

        Returns:
            One gradient per ``forward`` input, in order ``input``, ``dim``,
            ``descending``, ``out``: ``(grad_input, None, None, None)``.
        """
        indices, = ctx.saved_tensors
        grad_input = torch.zeros_like(grad_values)
        grad_input.scatter_(ctx.dim, indices, grad_values)
        return grad_input, None, None, None
|