|
|
@@ -7,12 +7,22 @@ class StableSort(torch.autograd.Function): |
|
|
|
def forward(ctx, input, dim=-1, descending=False, out=None):
|
|
|
|
values, indices = \
|
|
|
|
torch_stablesort_cpp.stable_sort(input, dim, descending, out)
|
|
|
|
ctx.save_for_backward(input, indices)
|
|
|
|
return values, indices
|
|
|
|
ctx.save_for_backward(input, indices, torch.tensor(dim))
|
|
|
|
return values, indices.detach()
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def backward(ctx, grad_values, grad_indices):
|
|
|
|
input, indices = ctx.saved_variables
|
|
|
|
res = torch.empty_like(grad_values)
|
|
|
|
res[indices] = grad_values + grad_indices
|
|
|
|
input, indices, dim = ctx.saved_variables
|
|
|
|
# print('backward(), grad_indices:', grad_indices, 'indices:', indices,
|
|
|
|
# 'grad_values:', grad_values)
|
|
|
|
|
|
|
|
res = torch.gather(grad_values, dim, indices)
|
|
|
|
|
|
|
|
# res = torch.empty_like(grad_values)
|
|
|
|
# print('here')
|
|
|
|
# res = res.view(-1, res.size(-1))
|
|
|
|
# indices = indices.view(-1, indices.size(-1))
|
|
|
|
# torch.repeat_interleave(torch.arange(indices.size(0))
|
|
|
|
# res[indices] = grad_values # + grad_indices
|
|
|
|
# print('here 2')
|
|
|
|
return res
|