|
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <vector>

#include <torch/extension.h>

#include "dispatch.h"
#include "torch_stablesort_cpu.h"
-
- template<bool descending, typename T>
- struct stable_sort_impl {
- std::vector<torch::Tensor> operator()(
- torch::Tensor input,
- int dim,
- torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
- ) const {
-
- if (input.is_sparse())
- throw std::runtime_error("Sparse tensors are not supported");
-
- if (input.device().type() != torch::DeviceType::CPU)
- throw std::runtime_error("Only CPU tensors are supported");
-
- if (out != torch::nullopt)
- throw std::runtime_error("out argument is not supported");
-
- auto in = (dim != -1) ?
- torch::transpose(input, dim, -1) :
- input;
-
- auto in_sizes = in.sizes();
-
- // std::cout << "in_sizes: " << in_sizes << std::endl;
-
- in = in.view({ -1, in.size(-1) }).contiguous();
-
- auto in_outer_stride = in.stride(-2);
- auto in_inner_stride = in.stride(-1);
-
- auto pin = static_cast<T*>(in.data_ptr());
-
- auto x = in.clone();
-
- auto x_outer_stride = x.stride(-2);
- auto x_inner_stride = x.stride(-1);
-
- auto n_cols = x.size(1);
- auto n_rows = x.size(0);
- auto px = static_cast<T*>(x.data_ptr());
-
- auto y = torch::empty({ n_rows, n_cols },
- torch::TensorOptions().dtype(torch::kInt64));
-
- auto y_outer_stride = y.stride(-2);
- auto y_inner_stride = y.stride(-1);
-
- auto py = static_cast<int64_t*>(y.data_ptr());
-
- #pragma omp parallel for
- for (decltype(n_rows) i = 0; i < n_rows; i++) {
- std::vector<int64_t> indices(n_cols);
- for (decltype(n_cols) k = 0; k < n_cols; k++) {
- indices[k] = k;
- }
-
- std::stable_sort(std::begin(indices), std::end(indices),
- [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
- auto va = pin[i * in_outer_stride + a * in_inner_stride];
- auto vb = pin[i * in_outer_stride + b * in_inner_stride];
- if constexpr(descending)
- return (vb < va);
- else
- return (va < vb);
- });
-
- for (decltype(n_cols) k = 0; k < n_cols; k++) {
- py[i * y_outer_stride + k * y_inner_stride] = indices[k];
- px[i * x_outer_stride + k * x_inner_stride] =
- pin[i * in_outer_stride + indices[k] * in_inner_stride];
- }
- }
-
- // std::cout << "Here" << std::endl;
-
- x = x.view(in_sizes);
- y = y.view(in_sizes);
-
- x = (dim == -1) ?
- x :
- torch::transpose(x, dim, -1).contiguous();
-
- y = (dim == -1) ?
- y :
- torch::transpose(y, dim, -1).contiguous();
-
- // std::cout << "Here 2" << std::endl;
-
- return { x, y };
- }
- };
-
- template <typename T>
- struct stable_sort_impl_desc: stable_sort_impl<true, T> {};
-
- template <typename T>
- struct stable_sort_impl_asc: stable_sort_impl<false, T> {};
-
- std::vector<torch::Tensor> dispatch_cpu(torch::Tensor input,
- int dim,
- bool descending,
- torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
-
- if (descending)
- return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>(
- input, dim, out);
- else
- return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>(
- input, dim, out);
- }
|