IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

140 lines
4.0KB

#include <torch/extension.h>

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

#include "dispatch.h"
  6. template<typename T>
  7. struct stable_sort_impl {
  8. std::vector<torch::Tensor> operator()(
  9. torch::Tensor input,
  10. int dim,
  11. bool descending,
  12. torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
  13. ) const {
  14. if (input.is_sparse())
  15. throw std::runtime_error("Sparse tensors are not supported");
  16. if (input.device().type() != torch::DeviceType::CPU)
  17. throw std::runtime_error("Only CPU tensors are supported");
  18. if (out != torch::nullopt)
  19. throw std::runtime_error("out argument is not supported");
  20. auto in = (dim != -1) ?
  21. torch::transpose(input, dim, -1) :
  22. input;
  23. auto in_sizes = in.sizes();
  24. // std::cout << "in_sizes: " << in_sizes << std::endl;
  25. in = in.view({ -1, in.size(-1) }).contiguous();
  26. auto in_outer_stride = in.stride(-2);
  27. auto in_inner_stride = in.stride(-1);
  28. auto pin = static_cast<T*>(in.data_ptr());
  29. auto x = in.clone();
  30. auto x_outer_stride = x.stride(-2);
  31. auto x_inner_stride = x.stride(-1);
  32. auto n_cols = x.size(1);
  33. auto n_rows = x.size(0);
  34. auto px = static_cast<T*>(x.data_ptr());
  35. auto y = torch::empty({ n_rows, n_cols },
  36. torch::TensorOptions().dtype(torch::kInt64));
  37. auto y_outer_stride = y.stride(-2);
  38. auto y_inner_stride = y.stride(-1);
  39. auto py = static_cast<int64_t*>(y.data_ptr());
  40. if (descending) {
  41. #pragma omp parallel for
  42. for (decltype(n_rows) i = 0; i < n_rows; i++) {
  43. std::vector<int64_t> indices(n_cols);
  44. for (decltype(n_cols) k = 0; k < n_cols; k++) {
  45. indices[k] = k;
  46. }
  47. std::stable_sort(std::begin(indices), std::end(indices),
  48. [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
  49. auto va = pin[i * in_outer_stride + a * in_inner_stride];
  50. auto vb = pin[i * in_outer_stride + b * in_inner_stride];
  51. return (vb < va);
  52. });
  53. for (decltype(n_cols) k = 0; k < n_cols; k++) {
  54. py[i * y_outer_stride + k * y_inner_stride] = indices[k];
  55. px[i * x_outer_stride + k * x_inner_stride] =
  56. pin[i * in_outer_stride + indices[k] * in_inner_stride];
  57. }
  58. }
  59. } else {
  60. #pragma omp parallel for
  61. for (decltype(n_rows) i = 0; i < n_rows; i++) {
  62. std::vector<int64_t> indices(n_cols);
  63. for (decltype(n_cols) k = 0; k < n_cols; k++) {
  64. indices[k] = k;
  65. }
  66. std::stable_sort(std::begin(indices), std::end(indices),
  67. [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
  68. auto va = pin[i * in_outer_stride + a * in_inner_stride];
  69. auto vb = pin[i * in_outer_stride + b * in_inner_stride];
  70. return (va < vb);
  71. });
  72. for (decltype(n_cols) k = 0; k < n_cols; k++) {
  73. py[i * y_outer_stride + k * y_inner_stride] = indices[k];
  74. px[i * x_outer_stride + k * x_inner_stride] =
  75. pin[i * in_outer_stride + indices[k] * in_inner_stride];
  76. }
  77. }
  78. }
  79. // std::cout << "Here" << std::endl;
  80. x = x.view(in_sizes);
  81. y = y.view(in_sizes);
  82. x = (dim == -1) ?
  83. x :
  84. torch::transpose(x, dim, -1).contiguous();
  85. y = (dim == -1) ?
  86. y :
  87. torch::transpose(y, dim, -1).contiguous();
  88. // std::cout << "Here 2" << std::endl;
  89. return { x, y };
  90. }
  91. };
  92. std::vector<torch::Tensor> stable_sort(
  93. torch::Tensor input,
  94. int dim = -1,
  95. bool descending = false,
  96. torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) {
  97. return dispatch<stable_sort_impl, std::vector<torch::Tensor>>(
  98. input, dim, descending, out);
  99. }
  100. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  101. m.def("stable_sort", &stable_sort, "Stable sort",
  102. py::arg("input"), py::arg("dim") = -1, py::arg("descending") = false,
  103. py::arg("out") = nullptr);
  104. }