IF YOU WOULD LIKE TO GET AN ACCOUNT, please email s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

29 lines
1.0KB

  1. import torch
  2. import torch_stablesort_cpp
  3. class StableSort(torch.autograd.Function):
  4. @staticmethod
  5. def forward(ctx, input, dim=-1, descending=False, out=None):
  6. values, indices = \
  7. torch_stablesort_cpp.stable_sort(input, dim, descending, out)
  8. ctx.save_for_backward(input, indices, torch.tensor(dim))
  9. return values, indices.detach()
  10. @staticmethod
  11. def backward(ctx, grad_values, grad_indices):
  12. input, indices, dim = ctx.saved_variables
  13. # print('backward(), grad_indices:', grad_indices, 'indices:', indices,
  14. # 'grad_values:', grad_values)
  15. res = torch.gather(grad_values, dim, indices)
  16. # res = torch.empty_like(grad_values)
  17. # print('here')
  18. # res = res.view(-1, res.size(-1))
  19. # indices = indices.view(-1, indices.size(-1))
  20. # torch.repeat_interleave(torch.arange(indices.size(0))
  21. # res[indices] = grad_values # + grad_indices
  22. # print('here 2')
  23. return res