IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

19 lines
593B

  1. import torch
  2. import torch_stablesort_cpp
  class StableSort(torch.autograd.Function):
      """Differentiable stable sort backed by the ``torch_stablesort_cpp`` extension.

      Call as ``StableSort.apply(input, dim, descending, out)``. Returns
      ``(values, indices)``; only ``values`` carries a gradient, ``indices``
      is marked non-differentiable.
      """

      @staticmethod
      def forward(ctx, input, dim=-1, descending=False, out=None):
          """Stably sort ``input`` along ``dim``.

          Args:
              ctx: Autograd context used to stash what ``backward`` needs.
              input: Tensor to sort.
              dim: Dimension to sort along (default: last dimension).
              descending: Sort in descending order when True.
              out: Passed straight through to
                  ``torch_stablesort_cpp.stable_sort``.
                  NOTE(review): its expected form is defined by the
                  extension, which is not visible here.

          Returns:
              ``(values, indices)`` as returned by the extension.
          """
          values, indices = torch_stablesort_cpp.stable_sort(
              input, dim, descending, out)
          # Integer sort positions have no meaningful gradient; telling
          # autograd so means backward's grad_indices can be ignored.
          ctx.mark_non_differentiable(indices)
          # The permutation and the sort axis are all backward needs;
          # the input tensor itself is not required.
          ctx.save_for_backward(indices)
          ctx.dim = dim
          return values, indices

      @staticmethod
      def backward(ctx, grad_values, grad_indices):
          """Route ``grad_values`` back to the pre-sort positions of ``input``.

          Assumes the extension follows ``torch.sort``'s convention
          ``values == input.gather(dim, indices)`` (TODO confirm against
          ``torch_stablesort_cpp``). Under that convention the gradient
          w.r.t. ``input`` is ``grad_values`` scattered along ``dim`` at the
          same indices. That is correct for any rank and any ``dim``, not
          just 1-D / dim 0.

          Args:
              ctx: Autograd context filled in by ``forward``.
              grad_values: Upstream gradient for ``values``.
              grad_indices: Upstream gradient for ``indices``. It is ignored
                  because ``indices`` is non-differentiable.

          Returns:
              One entry per ``forward`` input: the gradient for ``input``,
              then ``None`` for ``dim``, ``descending`` and ``out``. Autograd
              trims trailing ``None`` entries when fewer arguments were
              passed to ``apply``.
          """
          indices, = ctx.saved_tensors
          grad_input = torch.zeros_like(grad_values).scatter_(
              ctx.dim, indices, grad_values)
          return grad_input, None, None, None