|
|
@@ -6,12 +6,11 @@ |
|
|
|
|
|
|
|
#include "dispatch.h"
|
|
|
|
|
|
|
|
template<typename T>
|
|
|
|
template<bool descending, typename T>
|
|
|
|
struct stable_sort_impl {
|
|
|
|
std::vector<torch::Tensor> operator()(
|
|
|
|
torch::Tensor input,
|
|
|
|
int dim,
|
|
|
|
bool descending,
|
|
|
|
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
|
|
|
|
) const {
|
|
|
|
|
|
|
@@ -56,50 +55,27 @@ struct stable_sort_impl { |
|
|
|
|
|
|
|
auto py = static_cast<int64_t*>(y.data_ptr());
|
|
|
|
|
|
|
|
if (descending) {
|
|
|
|
#pragma omp parallel for
|
|
|
|
for (decltype(n_rows) i = 0; i < n_rows; i++) {
|
|
|
|
|
|
|
|
std::vector<int64_t> indices(n_cols);
|
|
|
|
for (decltype(n_cols) k = 0; k < n_cols; k++) {
|
|
|
|
indices[k] = k;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::stable_sort(std::begin(indices), std::end(indices),
|
|
|
|
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
|
|
|
|
auto va = pin[i * in_outer_stride + a * in_inner_stride];
|
|
|
|
auto vb = pin[i * in_outer_stride + b * in_inner_stride];
|
|
|
|
return (vb < va);
|
|
|
|
});
|
|
|
|
|
|
|
|
for (decltype(n_cols) k = 0; k < n_cols; k++) {
|
|
|
|
py[i * y_outer_stride + k * y_inner_stride] = indices[k];
|
|
|
|
px[i * x_outer_stride + k * x_inner_stride] =
|
|
|
|
pin[i * in_outer_stride + indices[k] * in_inner_stride];
|
|
|
|
}
|
|
|
|
#pragma omp parallel for
|
|
|
|
for (decltype(n_rows) i = 0; i < n_rows; i++) {
|
|
|
|
std::vector<int64_t> indices(n_cols);
|
|
|
|
for (decltype(n_cols) k = 0; k < n_cols; k++) {
|
|
|
|
indices[k] = k;
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
#pragma omp parallel for
|
|
|
|
for (decltype(n_rows) i = 0; i < n_rows; i++) {
|
|
|
|
|
|
|
|
std::vector<int64_t> indices(n_cols);
|
|
|
|
for (decltype(n_cols) k = 0; k < n_cols; k++) {
|
|
|
|
indices[k] = k;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::stable_sort(std::begin(indices), std::end(indices),
|
|
|
|
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
|
|
|
|
auto va = pin[i * in_outer_stride + a * in_inner_stride];
|
|
|
|
auto vb = pin[i * in_outer_stride + b * in_inner_stride];
|
|
|
|
std::stable_sort(std::begin(indices), std::end(indices),
|
|
|
|
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
|
|
|
|
auto va = pin[i * in_outer_stride + a * in_inner_stride];
|
|
|
|
auto vb = pin[i * in_outer_stride + b * in_inner_stride];
|
|
|
|
if constexpr(descending)
|
|
|
|
return (vb < va);
|
|
|
|
else
|
|
|
|
return (va < vb);
|
|
|
|
});
|
|
|
|
});
|
|
|
|
|
|
|
|
for (decltype(n_cols) k = 0; k < n_cols; k++) {
|
|
|
|
py[i * y_outer_stride + k * y_inner_stride] = indices[k];
|
|
|
|
px[i * x_outer_stride + k * x_inner_stride] =
|
|
|
|
pin[i * in_outer_stride + indices[k] * in_inner_stride];
|
|
|
|
}
|
|
|
|
for (decltype(n_cols) k = 0; k < n_cols; k++) {
|
|
|
|
py[i * y_outer_stride + k * y_inner_stride] = indices[k];
|
|
|
|
px[i * x_outer_stride + k * x_inner_stride] =
|
|
|
|
pin[i * in_outer_stride + indices[k] * in_inner_stride];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
@@ -122,14 +98,24 @@ struct stable_sort_impl { |
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
struct stable_sort_impl_desc: stable_sort_impl<true, T> {};
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
struct stable_sort_impl_asc: stable_sort_impl<false, T> {};
|
|
|
|
|
|
|
|
// Stable sort of `input` along `dim`.
//
// Parameters:
//   input      - tensor to sort.
//   dim        - dimension to sort along (defaults to -1).
//   descending - when true, the descending implementation is used;
//                otherwise the ascending one.
//   out        - optional pair of pre-existing tensors forwarded to the
//                implementation (presumably the values/indices outputs;
//                confirm against stable_sort_impl::operator()).
//
// Returns whatever the selected implementation returns through dispatch<>.
//
// The sort direction is a compile-time template argument of
// stable_sort_impl, so the runtime flag is translated here into a choice
// between the two single-parameter wrappers. `descending` is therefore not
// forwarded as a call argument.
// NOTE(review): dispatch<> is declared in dispatch.h; it presumably picks the
// element type from `input` — confirm there.
std::vector<torch::Tensor> stable_sort(
    torch::Tensor input,
    int dim = -1,
    bool descending = false,
    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) {
  if (descending) {
    return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>(
        input, dim, out);
  }
  return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>(
      input, dim, out);
}
|
|
|
|
|
|
|
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
|
|