#include #include #include #include #include "dispatch.h" template struct stable_sort_impl { std::vector operator()( torch::Tensor input, int dim, bool descending, torch::optional> out ) const { if (input.is_sparse()) throw std::runtime_error("Sparse tensors are not supported"); if (input.device().type() != torch::DeviceType::CPU) throw std::runtime_error("Only CPU tensors are supported"); if (out != torch::nullopt) throw std::runtime_error("out argument is not supported"); auto in = (dim != -1) ? torch::transpose(input, dim, -1) : input; auto in_sizes = in.sizes(); // std::cout << "in_sizes: " << in_sizes << std::endl; in = in.view({ -1, in.size(-1) }).contiguous(); auto in_outer_stride = in.stride(-2); auto in_inner_stride = in.stride(-1); auto pin = static_cast(in.data_ptr()); auto x = in.clone(); auto x_outer_stride = x.stride(-2); auto x_inner_stride = x.stride(-1); auto n_cols = x.size(1); auto n_rows = x.size(0); auto px = static_cast(x.data_ptr()); auto y = torch::empty({ n_rows, n_cols }, torch::TensorOptions().dtype(torch::kInt64)); auto y_outer_stride = y.stride(-2); auto y_inner_stride = y.stride(-1); auto py = static_cast(y.data_ptr()); if (descending) { #pragma omp parallel for for (decltype(n_rows) i = 0; i < n_rows; i++) { std::vector indices(n_cols); for (decltype(n_cols) k = 0; k < n_cols; k++) { indices[k] = k; } std::stable_sort(std::begin(indices), std::end(indices), [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { auto va = pin[i * in_outer_stride + a * in_inner_stride]; auto vb = pin[i * in_outer_stride + b * in_inner_stride]; return (vb < va); }); for (decltype(n_cols) k = 0; k < n_cols; k++) { py[i * y_outer_stride + k * y_inner_stride] = indices[k]; px[i * x_outer_stride + k * x_inner_stride] = pin[i * in_outer_stride + indices[k] * in_inner_stride]; } } } else { #pragma omp parallel for for (decltype(n_rows) i = 0; i < n_rows; i++) { std::vector indices(n_cols); for (decltype(n_cols) k = 0; k < n_cols; k++) { indices[k] = k; } std::stable_sort(std::begin(indices), std::end(indices), [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { auto va = pin[i * in_outer_stride + a * in_inner_stride]; auto vb = pin[i * in_outer_stride + b * in_inner_stride]; return (va < vb); }); for (decltype(n_cols) k = 0; k < n_cols; k++) { py[i * y_outer_stride + k * y_inner_stride] = indices[k]; px[i * x_outer_stride + k * x_inner_stride] = pin[i * in_outer_stride + indices[k] * in_inner_stride]; } } } // std::cout << "Here" << std::endl; x = x.view(in_sizes); y = y.view(in_sizes); x = (dim == -1) ? x : torch::transpose(x, dim, -1).contiguous(); y = (dim == -1) ? y : torch::transpose(y, dim, -1).contiguous(); // std::cout << "Here 2" << std::endl; return { x, y }; } }; std::vector stable_sort( torch::Tensor input, int dim = -1, bool descending = false, torch::optional> out = torch::nullopt) { return dispatch>( input, dim, descending, out); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("stable_sort", &stable_sort, "Stable sort", py::arg("input"), py::arg("dim") = -1, py::arg("descending") = false, py::arg("out") = nullptr); }