From b439a46fc33fd4b206aa7f1a71aeaba8404248ea Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 28 Aug 2020 22:31:44 +0200 Subject: [PATCH] Use if constexpr to make stable_sort_impl() more compact. --- src/torch_stablesort/setup.py | 2 +- src/torch_stablesort/torch_stablesort.cpp | 74 +++++++++-------------- 2 files changed, 31 insertions(+), 45 deletions(-) diff --git a/src/torch_stablesort/setup.py b/src/torch_stablesort/setup.py index a45a47f..cf5a856 100644 --- a/src/torch_stablesort/setup.py +++ b/src/torch_stablesort/setup.py @@ -5,5 +5,5 @@ setup(name='torch_stablesort', py_modules=['torch_stablesort'], ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', ['torch_stablesort.cpp'], - extra_compile_args=['-fopenmp', '-ggdb'])], + extra_compile_args=['-fopenmp', '-ggdb', '-std=c++1z'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/src/torch_stablesort/torch_stablesort.cpp b/src/torch_stablesort/torch_stablesort.cpp index 94b3e13..4bc3d6e 100644 --- a/src/torch_stablesort/torch_stablesort.cpp +++ b/src/torch_stablesort/torch_stablesort.cpp @@ -6,12 +6,11 @@ #include "dispatch.h" -template +template struct stable_sort_impl { std::vector operator()( torch::Tensor input, int dim, - bool descending, torch::optional> out ) const { @@ -56,50 +55,27 @@ struct stable_sort_impl { auto py = static_cast(y.data_ptr()); - if (descending) { - #pragma omp parallel for - for (decltype(n_rows) i = 0; i < n_rows; i++) { - - std::vector indices(n_cols); - for (decltype(n_cols) k = 0; k < n_cols; k++) { - indices[k] = k; - } - - std::stable_sort(std::begin(indices), std::end(indices), - [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { - auto va = pin[i * in_outer_stride + a * in_inner_stride]; - auto vb = pin[i * in_outer_stride + b * in_inner_stride]; - return (vb < va); - }); - - for (decltype(n_cols) k = 0; k < n_cols; k++) { - py[i * y_outer_stride + k * y_inner_stride] = indices[k]; - px[i * x_outer_stride + k * x_inner_stride] = - pin[i * in_outer_stride + indices[k] * in_inner_stride]; - } + #pragma omp parallel for + for (decltype(n_rows) i = 0; i < n_rows; i++) { + std::vector indices(n_cols); + for (decltype(n_cols) k = 0; k < n_cols; k++) { + indices[k] = k; } - } else { - #pragma omp parallel for - for (decltype(n_rows) i = 0; i < n_rows; i++) { - - std::vector indices(n_cols); - for (decltype(n_cols) k = 0; k < n_cols; k++) { - indices[k] = k; - } - - std::stable_sort(std::begin(indices), std::end(indices), - [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { - auto va = pin[i * in_outer_stride + a * in_inner_stride]; - auto vb = pin[i * in_outer_stride + b * in_inner_stride]; + std::stable_sort(std::begin(indices), std::end(indices), + [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { + auto va = pin[i * in_outer_stride + a * in_inner_stride]; + auto vb = pin[i * in_outer_stride + b * in_inner_stride]; + if constexpr(descending) + return (vb < va); + else return (va < vb); - }); + }); - for (decltype(n_cols) k = 0; k < n_cols; k++) { - py[i * y_outer_stride + k * y_inner_stride] = indices[k]; - px[i * x_outer_stride + k * x_inner_stride] = - pin[i * in_outer_stride + indices[k] * in_inner_stride]; - } + for (decltype(n_cols) k = 0; k < n_cols; k++) { + py[i * y_outer_stride + k * y_inner_stride] = indices[k]; + px[i * x_outer_stride + k * x_inner_stride] = + pin[i * in_outer_stride + indices[k] * in_inner_stride]; } } @@ -122,14 +98,24 @@ struct stable_sort_impl { } }; +template +struct stable_sort_impl_desc: stable_sort_impl {}; + +template +struct stable_sort_impl_asc: stable_sort_impl {}; + std::vector stable_sort( torch::Tensor input, int dim = -1, bool descending = false, torch::optional> out = torch::nullopt) { - return dispatch>( - input, dim, descending, out); + if (descending) + return dispatch>( + input, dim, out); + else + return dispatch>( + input, dim, out); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {