diff --git a/src/torch_stablesort/dispatch.h b/src/torch_stablesort/dispatch.h
index a7dfb59..ef91b46 100644
--- a/src/torch_stablesort/dispatch.h
+++ b/src/torch_stablesort/dispatch.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <utility>
 #include <torch/extension.h>
 
 template<template<typename T> class F, typename R, typename... Ts>
diff --git a/src/torch_stablesort/setup.py b/src/torch_stablesort/setup.py
index 8001773..ab9412c 100644
--- a/src/torch_stablesort/setup.py
+++ b/src/torch_stablesort/setup.py
@@ -3,8 +3,11 @@ from torch.utils import cpp_extension
 
 setup(name='torch_stablesort',
       py_modules=['torch_stablesort'],
-      ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp',
-          ['torch_stablesort.cpp'],
-          extra_compile_args=['-I/pstore/home/adaszews/scratch/thrust',
-              '-fopenmp', '-ggdb', '-std=c++1z'])],
-      cmdclass={'build_ext': cpp_extension.BuildExtension})
+      ext_modules=[cpp_extension.CUDAExtension('torch_stablesort_cpp',
+          ['torch_stablesort.cpp', 'torch_stablesort_cpu.cpp', 'torch_stablesort_cuda.cu'],
+          extra_compile_args={
+              'cxx': ['-fopenmp', '-ggdb', '-std=c++1z'],
+              'nvcc': ['-I/pstore/home/adaszews/scratch/thrust', '-ccbin',
+                  '/pstore/data/data_science/app/modules/anaconda3-2020.07/bin/x86_64-conda_cos6-linux-gnu-gcc',
+                  '-std=c++14']})],
+      cmdclass={'build_ext': cpp_extension.BuildExtension})
diff --git a/src/torch_stablesort/torch_stablesort_cpu.cpp b/src/torch_stablesort/torch_stablesort_cpu.cpp
new file mode 100644
index 0000000..12cef54
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort_cpu.cpp
@@ -0,0 +1,115 @@
+#include <torch/extension.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "dispatch.h"
+#include "torch_stablesort_cpu.h"
+
+template<typename T, bool descending>
+struct stable_sort_impl {
+  std::vector<torch::Tensor> operator()(
+    torch::Tensor input,
+    int dim,
+    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
+  ) const {
+
+    if (input.is_sparse())
+      throw std::runtime_error("Sparse tensors are not supported");
+
+    if (input.device().type() != torch::DeviceType::CPU)
+      throw std::runtime_error("Only CPU tensors are supported");
+
+    if (out != torch::nullopt)
+      throw std::runtime_error("out argument is not supported");
+
+    auto in = (dim != -1) ?
+      torch::transpose(input, dim, -1) :
+      input;
+
+    auto in_sizes = in.sizes();
+
+    // Flatten all leading dimensions; the sort always runs along the last axis.
+    in = in.contiguous().view({ -1, in.size(-1) });
+
+    auto in_outer_stride = in.stride(-2);
+    auto in_inner_stride = in.stride(-1);
+
+    auto pin = static_cast<T*>(in.data_ptr());
+
+    auto x = in.clone();
+
+    auto x_outer_stride = x.stride(-2);
+    auto x_inner_stride = x.stride(-1);
+
+    auto n_cols = x.size(1);
+    auto n_rows = x.size(0);
+    auto px = static_cast<T*>(x.data_ptr());
+
+    auto y = torch::empty({ n_rows, n_cols },
+      torch::TensorOptions().dtype(torch::kInt64));
+
+    auto y_outer_stride = y.stride(-2);
+    auto y_inner_stride = y.stride(-1);
+
+    auto py = static_cast<int64_t*>(y.data_ptr());
+
+    // Rows are sorted independently, so they parallelize cleanly.
+    #pragma omp parallel for
+    for (decltype(n_rows) i = 0; i < n_rows; i++) {
+      std::vector<int64_t> indices(n_cols);
+      for (decltype(n_cols) k = 0; k < n_cols; k++) {
+        indices[k] = k;
+      }
+
+      // Argsort: order the index vector by the values it points at.
+      std::stable_sort(std::begin(indices), std::end(indices),
+        [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
+          auto va = pin[i * in_outer_stride + a * in_inner_stride];
+          auto vb = pin[i * in_outer_stride + b * in_inner_stride];
+          if constexpr(descending)
+            return (vb < va);
+          else
+            return (va < vb);
+        });
+
+      // Write back the permuted values and their source indices.
+      for (decltype(n_cols) k = 0; k < n_cols; k++) {
+        py[i * y_outer_stride + k * y_inner_stride] = indices[k];
+        px[i * x_outer_stride + k * x_inner_stride] =
+          pin[i * in_outer_stride + indices[k] * in_inner_stride];
+      }
+    }
+
+    x = x.view(in_sizes);
+    y = y.view(in_sizes);
+
+    x = (dim == -1) ?
+      x :
+      torch::transpose(x, dim, -1).contiguous();
+
+    y = (dim == -1) ?
+      y :
+      torch::transpose(y, dim, -1).contiguous();
+
+    return { x, y };
+  }
+};
+
+template<typename T>
+struct stable_sort_impl_desc: stable_sort_impl<T, true> {};
+
+template<typename T>
+struct stable_sort_impl_asc: stable_sort_impl<T, false> {};
+
+std::vector<torch::Tensor> dispatch_cpu(torch::Tensor input,
+  int dim,
+  bool descending,
+  torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
+
+  if (descending)
+    return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>(
+      input, dim, out);
+  else
+    return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>(
+      input, dim, out);
+}
diff --git a/src/torch_stablesort/torch_stablesort_cpu.h b/src/torch_stablesort/torch_stablesort_cpu.h
index b2bbce1..7b54251 100644
--- a/src/torch_stablesort/torch_stablesort_cpu.h
+++ b/src/torch_stablesort/torch_stablesort_cpu.h
@@ -5,115 +5,7 @@
 #include <vector>
 #include <tuple>
 
-#include "dispatch.h"
-
-template<typename T, bool descending>
-struct stable_sort_impl {
-  std::vector<torch::Tensor> operator()(
-    torch::Tensor input,
-    int dim,
-    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
-  ) const {
-
-    if (input.is_sparse())
-      throw std::runtime_error("Sparse tensors are not supported");
-
-    if (input.device().type() != torch::DeviceType::CPU)
-      throw std::runtime_error("Only CPU tensors are supported");
-
-    if (out != torch::nullopt)
-      throw std::runtime_error("out argument is not supported");
-
-    auto in = (dim != -1) ?
-      torch::transpose(input, dim, -1) :
-      input;
-
-    auto in_sizes = in.sizes();
-
-    // std::cout << "in_sizes: " << in_sizes << std::endl;
-
-    in = in.view({ -1, in.size(-1) }).contiguous();
-
-    auto in_outer_stride = in.stride(-2);
-    auto in_inner_stride = in.stride(-1);
-
-    auto pin = static_cast<T*>(in.data_ptr());
-
-    auto x = in.clone();
-
-    auto x_outer_stride = x.stride(-2);
-    auto x_inner_stride = x.stride(-1);
-
-    auto n_cols = x.size(1);
-    auto n_rows = x.size(0);
-    auto px = static_cast<T*>(x.data_ptr());
-
-    auto y = torch::empty({ n_rows, n_cols },
-      torch::TensorOptions().dtype(torch::kInt64));
-
-    auto y_outer_stride = y.stride(-2);
-    auto y_inner_stride = y.stride(-1);
-
-    auto py = static_cast<int64_t*>(y.data_ptr());
-
-    #pragma omp parallel for
-    for (decltype(n_rows) i = 0; i < n_rows; i++) {
-      std::vector<int64_t> indices(n_cols);
-      for (decltype(n_cols) k = 0; k < n_cols; k++) {
-        indices[k] = k;
-      }
-
-      std::stable_sort(std::begin(indices), std::end(indices),
-        [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
-          auto va = pin[i * in_outer_stride + a * in_inner_stride];
-          auto vb = pin[i * in_outer_stride + b * in_inner_stride];
-          if constexpr(descending)
-            return (vb < va);
-          else
-            return (va < vb);
-        });
-
-      for (decltype(n_cols) k = 0; k < n_cols; k++) {
-        py[i * y_outer_stride + k * y_inner_stride] = indices[k];
-        px[i * x_outer_stride + k * x_inner_stride] =
-          pin[i * in_outer_stride + indices[k] * in_inner_stride];
-      }
-    }
-
-    // std::cout << "Here" << std::endl;
-
-    x = x.view(in_sizes);
-    y = y.view(in_sizes);
-
-    x = (dim == -1) ?
-      x :
-      torch::transpose(x, dim, -1).contiguous();
-
-    y = (dim == -1) ?
-      y :
-      torch::transpose(y, dim, -1).contiguous();
-
-    // std::cout << "Here 2" << std::endl;
-
-    return { x, y };
-  }
-};
-
-template<typename T>
-struct stable_sort_impl_desc: stable_sort_impl<T, true> {};
-
-template<typename T>
-struct stable_sort_impl_asc: stable_sort_impl<T, false> {};
-
 std::vector<torch::Tensor> dispatch_cpu(torch::Tensor input,
   int dim,
   bool descending,
-  torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
-
-  if (descending)
-    return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>(
-      input, dim, out);
-  else
-    return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>(
-      input, dim, out);
-}
+  torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out);
diff --git a/src/torch_stablesort/torch_stablesort_cuda.cu b/src/torch_stablesort/torch_stablesort_cuda.cu
new file mode 100644
index 0000000..bafa9ee
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort_cuda.cu
@@ -0,0 +1,105 @@
+#include <thrust/sort.h>
+#include <thrust/device_ptr.h>
+#include <thrust/functional.h>
+#include <thrust/execution_policy.h>
+
+#include "dispatch.h"
+#include "torch_stablesort_cuda.h"
+
+template<typename T, bool descending>
+struct stable_sort_impl_cuda {
+  std::vector<torch::Tensor> operator()(
+    torch::Tensor input,
+    int dim,
+    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
+  ) const {
+
+    if (input.is_sparse())
+      throw std::runtime_error("Sparse tensors are not supported");
+
+    if (input.device().type() != torch::DeviceType::CUDA)
+      throw std::runtime_error("Only CUDA tensors are supported");
+
+    if (out != torch::nullopt)
+      throw std::runtime_error("out argument is not supported");
+
+    auto x = input.clone();
+
+    if (dim != -1)
+      x = torch::transpose(x, dim, -1);
+
+    auto x_sizes = x.sizes();
+
+    x = x.contiguous().view({ -1, x.size(-1) });
+
+    auto x_outer_stride = x.stride(-2);
+    auto x_inner_stride = x.stride(-1);
+    auto n_cols = x.size(1);
+    auto n_rows = x.size(0);
+    auto px = x.data_ptr<T>();
+
+    assert(x_inner_stride == 1);
+
+    // Index matrix on the device: every row starts out as 0..n_cols-1.
+    auto y = torch::arange(0, n_cols, 1, torch::TensorOptions()
+        .dtype(torch::kInt32)
+        .device(x.device()))
+      .view({ 1, n_cols })
+      .repeat({ n_rows, 1 })
+      .contiguous();
+
+    auto y_outer_stride = y.stride(-2);
+    auto y_inner_stride = y.stride(-1);
+    auto py = y.data_ptr<int32_t>();
+
+    assert(y_inner_stride == 1);
+
+    // Sort each row separately: thrust permutes the index row alongside the
+    // values, yielding sorted values and argsort indices in one pass.
+    for (decltype(n_rows) i = 0; i < n_rows; i++) {
+      auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
+
+      auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
+      auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
+        n_cols * x_inner_stride);
+
+      if constexpr(descending)
+        thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg,
+          thrust::greater<T>());
+      else
+        thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg);
+    }
+
+    x = x.view(x_sizes);
+    y = y.view(x_sizes);
+
+    x = (dim == -1) ?
+      x :
+      torch::transpose(x, dim, -1).contiguous();
+
+    y = (dim == -1) ?
+      y :
+      torch::transpose(y, dim, -1).contiguous();
+
+    return { x, y };
+  }
+};
+
+template<typename T>
+struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<T, true> {};
+
+template<typename T>
+struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<T, false> {};
+
+std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input,
+  int dim,
+  bool descending,
+  torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
+
+  if (descending)
+    return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>(
+      input, dim, out);
+  else
+    return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>(
+      input, dim, out);
+}
diff --git a/src/torch_stablesort/torch_stablesort_cuda.h b/src/torch_stablesort/torch_stablesort_cuda.h
index 496429f..e5f10e4 100644
--- a/src/torch_stablesort/torch_stablesort_cuda.h
+++ b/src/torch_stablesort/torch_stablesort_cuda.h
@@ -2,108 +2,10 @@
 
 #include <torch/extension.h>
 
-#include <thrust/sort.h>
-#include <thrust/device_ptr.h>
-#include <thrust/execution_policy.h>
-
 #include <vector>
 #include <tuple>
 
-#include "dispatch.h"
-
-template<typename T, bool descending>
-struct stable_sort_impl_cuda {
-  std::vector<torch::Tensor> operator()(
-    torch::Tensor input,
-    int dim,
-    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
-  ) const {
-
-    if (input.is_sparse())
-      throw std::runtime_error("Sparse tensors are not supported");
-
-    if (input.device().type() != torch::DeviceType::CUDA)
-      throw std::runtime_error("Only CUDA tensors are supported");
-
-    if (out != torch::nullopt)
-      throw std::runtime_error("out argument is not supported");
-
-    auto x = input.clone();
-
-    if (dim != -1)
-      x = torch::transpose(x, dim, -1);
-
-    auto x_sizes = x.sizes();
-
-    x = x.view({ -1, x.size(-1) }).contiguous();
-
-    auto x_outer_stride = x.stride(-2);
-    auto x_inner_stride = x.stride(-1);
-    auto n_cols = x.size(1);
-    auto n_rows = x.size(0);
-    auto px = x.data_ptr<T>();
-
-    assert(x_inner_stride == 1);
-
-    auto y = torch::repeat_interleave(
-      torch::arange(0, n_cols, 1, torch::TensorOptions()
-        .dtype(torch::kInt32)
-        .device(x.device())),
-      torch::ones(n_rows, torch::TensorOptions()
-        .dtype(torch::kInt32)
-        .device(x.device()))
-    );
-
-    auto y_outer_stride = y.stride(-2);
-    auto y_inner_stride = y.stride(-1);
-    auto py = y.data_ptr<int32_t>();
-
-    assert(y_inner_stride == 1);
-
-    for (decltype(n_rows) i = 0; i < n_rows; i++) {
-      auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
-
-      auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
-      auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
-        n_cols * x_inner_stride);
-
-      if constexpr(descending)
-        thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg,
-          thrust::greater<T>());
-      else
-        thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg);
-    }
-
-    x = x.view(x_sizes);
-    y = y.view(x_sizes);
-
-    x = (dim == -1) ?
-      x :
-      torch::transpose(x, dim, -1).contiguous();
-
-    y = (dim == -1) ?
-      y :
-      torch::transpose(y, dim, -1).contiguous();
-
-    return { x, y };
-  }
-};
-
-template<typename T>
-struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<T, true> {};
-
-template<typename T>
-struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<T, false> {};
-
 std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input,
   int dim,
   bool descending,
-  torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
-
-  if (descending)
-    return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>(
-      input, dim, out);
-  else
-    return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>(
-      input, dim, out);
-}
+  torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out);
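
Note on dispatch.h: the diff touches only its include list, so the body of dispatch() never appears here. The sketch below is a hypothetical reconstruction, not the committed code; it assumes the helper maps the tensor's runtime dtype onto the F<T> functor via AT_DISPATCH_ALL_TYPES, which would also explain why the commit adds <utility> (for std::forward):

    // Hypothetical sketch of dispatch.h -- an assumption, not the committed code.
    #pragma once

    #include <utility>
    #include <torch/extension.h>

    // Maps the runtime dtype of `input` onto the compile-time parameter of F,
    // then invokes F<scalar_t>() with the remaining arguments forwarded intact.
    template<template<typename T> class F, typename R, typename... Ts>
    R dispatch(torch::Tensor input, Ts&& ...args) {
      R result;
      AT_DISPATCH_ALL_TYPES(input.scalar_type(), "dispatch", [&] {
        // scalar_t is injected by the macro for each dtype case.
        result = F<scalar_t>()(input, std::forward<Ts>(args)...);
      });
      return result;
    }

Under that assumption, dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>(input, dim, out) in dispatch_cpu() instantiates stable_sort_impl_desc<scalar_t> for whatever dtype input carries at run time, and the CUDA path works the same way.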