#pragma once

#include <torch/extension.h>

#include <algorithm>
#include <stdexcept>
#include <tuple>
#include <vector>

#include "dispatch.h"

template<bool descending, typename T>
struct stable_sort_impl {
    std::vector<torch::Tensor> operator()(
        torch::Tensor input,
        int dim,
        torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
    ) const {

        if (input.is_sparse())
            throw std::runtime_error("Sparse tensors are not supported");

        if (input.device().type() != torch::DeviceType::CPU)
            throw std::runtime_error("Only CPU tensors are supported");

        if (out != torch::nullopt)
            throw std::runtime_error("out argument is not supported");

        // Move the sort dimension to the innermost position so every slice to
        // be sorted is laid out as one row.
        auto in = (dim != -1) ?
            torch::transpose(input, dim, -1) :
            input;

        // Copy the (possibly transposed) sizes so they remain valid after `in`
        // is reassigned below.
        auto in_sizes = in.sizes().vec();

        // Flatten to 2-D: one row per independent sort. contiguous() must come
        // before view(), since the transposed tensor is generally not viewable.
        in = in.contiguous().view({ -1, in.size(-1) });
        auto in_outer_stride = in.stride(-2);
        auto in_inner_stride = in.stride(-1);
        auto pin = static_cast<T*>(in.data_ptr());

        // x receives the sorted values, y the corresponding int64 indices.
        auto x = in.clone();
        auto x_outer_stride = x.stride(-2);
        auto x_inner_stride = x.stride(-1);
        auto n_cols = x.size(1);
        auto n_rows = x.size(0);
        auto px = static_cast<T*>(x.data_ptr());

        auto y = torch::empty({ n_rows, n_cols },
            torch::TensorOptions().dtype(torch::kInt64));
        auto y_outer_stride = y.stride(-2);
        auto y_inner_stride = y.stride(-1);
        auto py = static_cast<int64_t*>(y.data_ptr());

        // Sort every row independently; rows are distributed across OpenMP threads.
        #pragma omp parallel for
        for (decltype(n_rows) i = 0; i < n_rows; i++) {
            std::vector<int64_t> indices(n_cols);
            for (decltype(n_cols) k = 0; k < n_cols; k++) {
                indices[k] = k;
            }

            // Argsort the row with std::stable_sort so that equal elements
            // keep their original relative order.
            std::stable_sort(std::begin(indices), std::end(indices),
                [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
                    auto va = pin[i * in_outer_stride + a * in_inner_stride];
                    auto vb = pin[i * in_outer_stride + b * in_inner_stride];
                    if constexpr(descending)
                        return (vb < va);
                    else
                        return (va < vb);
                });

            // Write the permutation and the permuted values into the outputs.
            for (decltype(n_cols) k = 0; k < n_cols; k++) {
                py[i * y_outer_stride + k * y_inner_stride] = indices[k];
                px[i * x_outer_stride + k * x_inner_stride] =
                    pin[i * in_outer_stride + indices[k] * in_inner_stride];
            }
        }

        // Restore the original shape and dimension order.
        x = x.view(in_sizes);
        y = y.view(in_sizes);

        x = (dim == -1) ?
            x :
            torch::transpose(x, dim, -1).contiguous();

        y = (dim == -1) ?
            y :
            torch::transpose(y, dim, -1).contiguous();

        return { x, y };
    }
};

template <typename T>
struct stable_sort_impl_desc: stable_sort_impl<true, T> {};

template <typename T>
struct stable_sort_impl_asc: stable_sort_impl<false, T> {};

// Entry point for CPU tensors: picks the ascending or descending implementation
// and dispatches on the tensor's scalar type (see dispatch.h).
std::vector<torch::Tensor> dispatch_cpu(torch::Tensor input,
    int dim,
    bool descending,
    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {

    if (descending)
        return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>(
            input, dim, out);
    else
        return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>(
            input, dim, out);
}
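
// ---------------------------------------------------------------------------
// A minimal sketch (not part of the original file) of how dispatch_cpu could
// be exposed to Python when this header is compiled as a regular torch C++
// extension. The header and function names below are assumptions made for
// illustration only.
//
//     #include <torch/extension.h>
//     #include "cpu.h"  // hypothetical name of this header
//
//     PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//         m.def("stable_sort_cpu", &dispatch_cpu,
//               "Row-wise stable sort on CPU tensors; returns {values, indices}.");
//     }
//
// From Python, the built module could then be called roughly as
// values, indices = ext.stable_sort_cpu(t, dim, descending, None).
// ---------------------------------------------------------------------------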