IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Use if constexpr to make stable_sort_impl() more compact.

master
Stanislaw Adaszewski 3 years ago
parent
commit
b439a46fc3
2 changed files with 31 additions and 45 deletions
  1. +1
    -1
      src/torch_stablesort/setup.py
  2. +30
    -44
      src/torch_stablesort/torch_stablesort.cpp

+ 1
- 1
src/torch_stablesort/setup.py View File

@@ -5,5 +5,5 @@ setup(name='torch_stablesort',
py_modules=['torch_stablesort'],
ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp',
['torch_stablesort.cpp'],
extra_compile_args=['-fopenmp', '-ggdb'])],
extra_compile_args=['-fopenmp', '-ggdb', '-std=c++1z'])],
cmdclass={'build_ext': cpp_extension.BuildExtension})

+ 30
- 44
src/torch_stablesort/torch_stablesort.cpp View File

@@ -6,12 +6,11 @@
#include "dispatch.h"
template<typename T>
template<bool descending, typename T>
struct stable_sort_impl {
std::vector<torch::Tensor> operator()(
torch::Tensor input,
int dim,
bool descending,
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
) const {
@@ -56,50 +55,27 @@ struct stable_sort_impl {
auto py = static_cast<int64_t*>(y.data_ptr());
if (descending) {
#pragma omp parallel for
for (decltype(n_rows) i = 0; i < n_rows; i++) {
std::vector<int64_t> indices(n_cols);
for (decltype(n_cols) k = 0; k < n_cols; k++) {
indices[k] = k;
}
std::stable_sort(std::begin(indices), std::end(indices),
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
auto va = pin[i * in_outer_stride + a * in_inner_stride];
auto vb = pin[i * in_outer_stride + b * in_inner_stride];
return (vb < va);
});
for (decltype(n_cols) k = 0; k < n_cols; k++) {
py[i * y_outer_stride + k * y_inner_stride] = indices[k];
px[i * x_outer_stride + k * x_inner_stride] =
pin[i * in_outer_stride + indices[k] * in_inner_stride];
}
#pragma omp parallel for
for (decltype(n_rows) i = 0; i < n_rows; i++) {
std::vector<int64_t> indices(n_cols);
for (decltype(n_cols) k = 0; k < n_cols; k++) {
indices[k] = k;
}
} else {
#pragma omp parallel for
for (decltype(n_rows) i = 0; i < n_rows; i++) {
std::vector<int64_t> indices(n_cols);
for (decltype(n_cols) k = 0; k < n_cols; k++) {
indices[k] = k;
}
std::stable_sort(std::begin(indices), std::end(indices),
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
auto va = pin[i * in_outer_stride + a * in_inner_stride];
auto vb = pin[i * in_outer_stride + b * in_inner_stride];
std::stable_sort(std::begin(indices), std::end(indices),
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
auto va = pin[i * in_outer_stride + a * in_inner_stride];
auto vb = pin[i * in_outer_stride + b * in_inner_stride];
if constexpr(descending)
return (vb < va);
else
return (va < vb);
});
});
for (decltype(n_cols) k = 0; k < n_cols; k++) {
py[i * y_outer_stride + k * y_inner_stride] = indices[k];
px[i * x_outer_stride + k * x_inner_stride] =
pin[i * in_outer_stride + indices[k] * in_inner_stride];
}
for (decltype(n_cols) k = 0; k < n_cols; k++) {
py[i * y_outer_stride + k * y_inner_stride] = indices[k];
px[i * x_outer_stride + k * x_inner_stride] =
pin[i * in_outer_stride + indices[k] * in_inner_stride];
}
}
@@ -122,14 +98,24 @@ struct stable_sort_impl {
}
};
template <typename T>
struct stable_sort_impl_desc: stable_sort_impl<true, T> {};
template <typename T>
struct stable_sort_impl_asc: stable_sort_impl<false, T> {};
std::vector<torch::Tensor> stable_sort(
torch::Tensor input,
int dim = -1,
bool descending = false,
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) {
return dispatch<stable_sort_impl, std::vector<torch::Tensor>>(
input, dim, descending, out);
if (descending)
return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>(
input, dim, out);
else
return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>(
input, dim, out);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {


Loading…
Cancel
Save