#pragma once

#include <vector>
#include <stdexcept>
#include <cassert>

#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/functional.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/execution_policy.h>

#include "dispatch.h"
#include "torch_stablesort_cuda.h"
template<bool descending, typename T>
struct stable_sort_impl_cuda {
    std::vector<torch::Tensor> operator()(
        torch::Tensor input,
        int dim,
        torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
    ) const {

        if (input.is_sparse())
            throw std::runtime_error("Sparse tensors are not supported");

        if (input.device().type() != torch::DeviceType::CUDA)
            throw std::runtime_error("Only CUDA tensors are supported");

        if (out != torch::nullopt)
            throw std::runtime_error("out argument is not supported");

        // Work on a copy, with the sort dimension moved to the innermost position.
        auto values = input.clone();
        if (dim != -1)
            values = torch::transpose(values, dim, -1);

        // Copy the sizes so they remain valid after values is reassigned below.
        auto orig_sizes = values.sizes().vec();

        // Flatten into a contiguous (n_rows x n_cols) buffer; each row is one
        // independent slice to sort.
        values = values.view({ -1, values.size(-1) }).contiguous();
        auto n_cols = values.size(1);
        auto n_rows = values.size(0);
        auto n = n_rows * n_cols;

        assert(values.stride(-2) == n_cols);
        assert(values.stride(-1) == 1);
        auto values_ptr = values.data_ptr<T>();

        // Global indices 0..n-1, later reduced to per-row positions.
        auto indices = torch::arange(0, n, 1, torch::TensorOptions()
            .dtype(torch::kInt64)
            .device(values.device())).view({ n_rows, n_cols });
        assert(indices.stride(-2) == n_cols);
        assert(indices.stride(-1) == 1);
        auto indices_ptr = indices.data_ptr<int64_t>();

        auto ind_beg = thrust::device_pointer_cast(indices_ptr);
        auto val_beg = thrust::device_pointer_cast(values_ptr);

        // Pass 1: stable sort all elements globally, carrying the indices along.
        if (descending)
            thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n,
                ind_beg, thrust::greater<T>());
        else
            thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg);

        // Pass 2: stable sort by row (segment) id, so elements regroup into
        // their rows while the within-row order from pass 1 is preserved.
        thrust::device_vector<int64_t> segments(n);
        thrust::constant_iterator<int64_t> n_cols_iter(n_cols);

        thrust::transform(thrust::device,
            ind_beg, ind_beg + n, n_cols_iter,
            segments.begin(), thrust::divides<int64_t>());
        thrust::stable_sort_by_key(thrust::device, segments.begin(),
            segments.end(), val_beg);

        // The previous sort reordered the segment keys, so recompute them from
        // the (still unsorted) indices before applying the same sort to them.
        thrust::transform(thrust::device,
            ind_beg, ind_beg + n, n_cols_iter,
            segments.begin(), thrust::divides<int64_t>());
        thrust::stable_sort_by_key(thrust::device, segments.begin(),
            segments.end(), ind_beg);

        // Convert global indices to positions within each row.
        thrust::transform(thrust::device, ind_beg, ind_beg + n,
            n_cols_iter, ind_beg, thrust::modulus<int64_t>());

        cudaDeviceSynchronize();

        // Restore the original shape and dimension order.
        values = values.view(orig_sizes);
        indices = indices.view(orig_sizes);

        if (dim != -1) {
            values = torch::transpose(values, dim, -1).contiguous();
            indices = torch::transpose(indices, dim, -1).contiguous();
        }

        return { values, indices };
    }
};
template <typename T>
struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {};

template <typename T>
struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {};
// Entry point: picks the ascending or descending implementation and forwards
// to the type-dispatch helper from dispatch.h, which instantiates it for the
// input tensor's scalar type.
std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input,
    int dim,
    bool descending,
    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {

    if (descending)
        return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>(
            input, dim, out);
    else
        return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>(
            input, dim, out);
}
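
For reference, a minimal usage sketch (an assumption, not part of the original file): it calls dispatch_cuda on a random CUDA tensor and compares the resulting values against torch::sort, assuming this header is included and the program is linked against libtorch with a CUDA device available.

// Hypothetical standalone check, not part of the original file.
#include <iostream>
#include <tuple>
#include <torch/torch.h>

int main() {
    auto x = torch::randn({4, 8}, torch::device(torch::kCUDA));

    // Stable sort along dim 1 in descending order; returns { values, indices }.
    auto result = dispatch_cuda(x, /*dim=*/1, /*descending=*/true, torch::nullopt);

    // Values should agree with torch::sort; indices may differ only for ties,
    // where the stable variant keeps the original order.
    auto reference = torch::sort(x, /*dim=*/1, /*descending=*/true);
    std::cout << "values match: "
              << torch::allclose(result[0], std::get<0>(reference)) << std::endl;
    return 0;
}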