IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
3.0KB

  1. #pragma once
  2. #include <torch/extension.h>
  3. #include <thrust/sort.h>
  4. #include <thrust/device_ptr.h>
  5. #include <thrust/execution_policy.h>
  6. #include <vector>
  7. #include <tuple>
  8. #include "dispatch.h"
  9. template<bool descending, typename T>
  10. struct stable_sort_impl_cuda {
  11. std::vector<torch::Tensor> operator()(
  12. torch::Tensor input,
  13. int dim,
  14. torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
  15. ) const {
  16. if (input.is_sparse())
  17. throw std::runtime_error("Sparse tensors are not supported");
  18. if (input.device().type() != torch::DeviceType::CUDA)
  19. throw std::runtime_error("Only CUDA tensors are supported");
  20. if (out != torch::nullopt)
  21. throw std::runtime_error("out argument is not supported");
  22. auto x = input.clone();
  23. if (dim != -1)
  24. x = torch::transpose(x, dim, -1);
  25. auto x_sizes = x.sizes();
  26. x = x.view({ -1, x.size(-1) }).contiguous();
  27. auto x_outer_stride = x.stride(-2);
  28. auto x_inner_stride = x.stride(-1);
  29. auto n_cols = x.size(1);
  30. auto n_rows = x.size(0);
  31. auto px = x.data_ptr<T>();
  32. assert(x_inner_stride == 1);
  33. auto y = torch::repeat_interleave(
  34. torch::arange(0, n_cols, 1, torch::TensorOptions()
  35. .dtype(torch::kInt32)
  36. .device(x.device())),
  37. torch::ones(n_rows, torch::TensorOptions()
  38. .dtype(torch::kInt32)
  39. .device(x.device()))
  40. );
  41. auto y_outer_stride = y.stride(-2);
  42. auto y_inner_stride = y.stride(-1);
  43. auto py = y.data_ptr<int32_t>();
  44. assert(y_inner_stride == 1);
  45. for (decltype(n_rows) i = 0; i < n_rows; i++) {
  46. auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
  47. auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
  48. auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
  49. n_cols * x_inner_stride);
  50. if constexpr(descending)
  51. thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg,
  52. thrust::greater<T>());
  53. else
  54. thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg);
  55. }
  56. x = x.view(x_sizes);
  57. y = y.view(x_sizes);
  58. x = (dim == -1) ?
  59. x :
  60. torch::transpose(x, dim, -1).contiguous();
  61. y = (dim == -1) ?
  62. y :
  63. torch::transpose(y, dim, -1).contiguous();
  64. return { x, y };
  65. }
  66. };
  67. template <typename T>
  68. struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {};
  69. template <typename T>
  70. struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {};
  71. std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input,
  72. int dim,
  73. bool descending,
  74. torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
  75. if (descending)
  76. return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>(
  77. input, dim, out);
  78. else
  79. return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>(
  80. input, dim, out);
  81. }