IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

117 lines
3.5KB

#pragma once

#include <cassert>
#include <stdexcept>
#include <tuple>
#include <vector>

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

#include "dispatch.h"
#include "torch_stablesort_cuda.h"
  7. template<bool descending, typename T>
  8. struct stable_sort_impl_cuda {
  9. std::vector<torch::Tensor> operator()(
  10. torch::Tensor input,
  11. int dim,
  12. torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
  13. ) const {
  14. if (input.is_sparse())
  15. throw std::runtime_error("Sparse tensors are not supported");
  16. if (input.device().type() != torch::DeviceType::CUDA)
  17. throw std::runtime_error("Only CUDA tensors are supported");
  18. if (out != torch::nullopt)
  19. throw std::runtime_error("out argument is not supported");
  20. auto x = input.clone();
  21. if (dim != -1)
  22. x = torch::transpose(x, dim, -1);
  23. auto x_sizes = x.sizes();
  24. x = x.view({ -1, x.size(-1) }).contiguous();
  25. auto x_outer_stride = x.stride(-2);
  26. auto x_inner_stride = x.stride(-1);
  27. auto n_cols = x.size(1);
  28. auto n_rows = x.size(0);
  29. auto px = x.data_ptr<T>();
  30. assert(x_inner_stride == 1);
  31. auto y = torch::repeat_interleave(
  32. torch::arange(0, n_cols, 1, torch::TensorOptions()
  33. .dtype(torch::kInt64)
  34. .device(x.device())).view({ 1, -1 }),
  35. n_rows,
  36. 0 /* dim */
  37. );
  38. auto y_outer_stride = y.stride(-2);
  39. auto y_inner_stride = y.stride(-1);
  40. auto py = y.data_ptr<int64_t>();
  41. assert(y_inner_stride == 1);
  42. #define NUM_STREAMS 16
  43. cudaStream_t streams[NUM_STREAMS];
  44. for(int i = 0; i < NUM_STREAMS; i++)
  45. assert(cudaStreamCreate(&streams[i]) == cudaSuccess);
  46. thrust::host_vector<int64_t> row_indices(n_rows);
  47. thrust::sequence(row_indices.begin(), row_indices.end());
  48. thrust::for_each(thrust::host, row_indices.begin(), row_indices.end(),
  49. [&streams, py, y_outer_stride, px, x_outer_stride, x_inner_stride, n_cols](int64_t i) {
  50. auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
  51. auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
  52. auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
  53. n_cols * x_inner_stride);
  54. if (descending)
  55. thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg,
  56. thrust::greater<T>());
  57. else
  58. thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg);
  59. });
  60. cudaDeviceSynchronize();
  61. for(int i = 0; i < NUM_STREAMS; i++)
  62. assert(cudaStreamDestroy(streams[i]) == cudaSuccess);
  63. x = x.view(x_sizes);
  64. y = y.view(x_sizes);
  65. x = (dim == -1) ?
  66. x :
  67. torch::transpose(x, dim, -1).contiguous();
  68. y = (dim == -1) ?
  69. y :
  70. torch::transpose(y, dim, -1).contiguous();
  71. return { x, y };
  72. }
  73. };
// Single-parameter wrappers that bake in the `descending` flag, so that
// the sort direction can be selected with a template-template argument
// (as dispatch_cuda() does below). Kept as inheriting structs — not alias
// templates — so they match template-template parameters portably.
template <typename T>
struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {};

template <typename T>
struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {};
  78. std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input,
  79. int dim,
  80. bool descending,
  81. torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
  82. if (descending)
  83. return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>(
  84. input, dim, out);
  85. else
  86. return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>(
  87. input, dim, out);
  88. }