From f8f2901eec5d3d421aaf5569431335fb2e8bacd2 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Mon, 31 Aug 2020 20:40:55 +0200 Subject: [PATCH] Start working on CUDA implementation of torch_stablesort. --- src/torch_stablesort/dispatch.h | 2 + src/torch_stablesort/setup.py | 5 +- src/torch_stablesort/torch_stablesort.cpp | 115 ++---------------- src/torch_stablesort/torch_stablesort_cpu.h | 119 +++++++++++++++++++ src/torch_stablesort/torch_stablesort_cuda.h | 109 +++++++++++++++++ 5 files changed, 243 insertions(+), 107 deletions(-) create mode 100644 src/torch_stablesort/torch_stablesort_cpu.h create mode 100644 src/torch_stablesort/torch_stablesort_cuda.h diff --git a/src/torch_stablesort/dispatch.h b/src/torch_stablesort/dispatch.h index 1dc9b49..a7dfb59 100644 --- a/src/torch_stablesort/dispatch.h +++ b/src/torch_stablesort/dispatch.h @@ -1,3 +1,5 @@ +#pragma once + #include template class F, typename R, typename... Ts> diff --git a/src/torch_stablesort/setup.py b/src/torch_stablesort/setup.py index cf5a856..8001773 100644 --- a/src/torch_stablesort/setup.py +++ b/src/torch_stablesort/setup.py @@ -5,5 +5,6 @@ setup(name='torch_stablesort', py_modules=['torch_stablesort'], ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', ['torch_stablesort.cpp'], - extra_compile_args=['-fopenmp', '-ggdb', '-std=c++1z'])], - cmdclass={'build_ext': cpp_extension.BuildExtension}) + extra_compile_args=['-I/pstore/home/adaszews/scratch/thrust', + '-fopenmp', '-ggdb', '-std=c++1z'])], + cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/src/torch_stablesort/torch_stablesort.cpp b/src/torch_stablesort/torch_stablesort.cpp index 4bc3d6e..bbcac3c 100644 --- a/src/torch_stablesort/torch_stablesort.cpp +++ b/src/torch_stablesort/torch_stablesort.cpp @@ -4,105 +4,8 @@ #include #include -#include "dispatch.h" - -template -struct stable_sort_impl { - std::vector operator()( - torch::Tensor input, - int dim, - torch::optional> out - ) const { - - if (input.is_sparse()) - throw std::runtime_error("Sparse tensors are not supported"); - - if (input.device().type() != torch::DeviceType::CPU) - throw std::runtime_error("Only CPU tensors are supported"); - - if (out != torch::nullopt) - throw std::runtime_error("out argument is not supported"); - - auto in = (dim != -1) ? - torch::transpose(input, dim, -1) : - input; - - auto in_sizes = in.sizes(); - - // std::cout << "in_sizes: " << in_sizes << std::endl; - - in = in.view({ -1, in.size(-1) }).contiguous(); - - auto in_outer_stride = in.stride(-2); - auto in_inner_stride = in.stride(-1); - - auto pin = static_cast(in.data_ptr()); - - auto x = in.clone(); - - auto x_outer_stride = x.stride(-2); - auto x_inner_stride = x.stride(-1); - - auto n_cols = x.size(1); - auto n_rows = x.size(0); - auto px = static_cast(x.data_ptr()); - - auto y = torch::empty({ n_rows, n_cols }, - torch::TensorOptions().dtype(torch::kInt64)); - - auto y_outer_stride = y.stride(-2); - auto y_inner_stride = y.stride(-1); - - auto py = static_cast(y.data_ptr()); - - #pragma omp parallel for - for (decltype(n_rows) i = 0; i < n_rows; i++) { - std::vector indices(n_cols); - for (decltype(n_cols) k = 0; k < n_cols; k++) { - indices[k] = k; - } - - std::stable_sort(std::begin(indices), std::end(indices), - [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { - auto va = pin[i * in_outer_stride + a * in_inner_stride]; - auto vb = pin[i * in_outer_stride + b * in_inner_stride]; - if constexpr(descending) - return (vb < va); - else - return (va < vb); - }); - - for (decltype(n_cols) k = 0; k < n_cols; k++) { - py[i * y_outer_stride + k * y_inner_stride] = indices[k]; - px[i * x_outer_stride + k * x_inner_stride] = - pin[i * in_outer_stride + indices[k] * in_inner_stride]; - } - } - - // std::cout << "Here" << std::endl; - - x = x.view(in_sizes); - y = y.view(in_sizes); - - x = (dim == -1) ? - x : - torch::transpose(x, dim, -1).contiguous(); - - y = (dim == -1) ? - y : - torch::transpose(y, dim, -1).contiguous(); - - // std::cout << "Here 2" << std::endl; - - return { x, y }; - } -}; - -template -struct stable_sort_impl_desc: stable_sort_impl {}; - -template -struct stable_sort_impl_asc: stable_sort_impl {}; +#include "torch_stablesort_cuda.h" +#include "torch_stablesort_cpu.h" std::vector stable_sort( torch::Tensor input, @@ -110,12 +13,14 @@ std::vector stable_sort( bool descending = false, torch::optional> out = torch::nullopt) { - if (descending) - return dispatch>( - input, dim, out); - else - return dispatch>( - input, dim, out); + switch (input.device().type()) { + case torch::DeviceType::CUDA: + return dispatch_cuda(input, dim, descending, out); + case torch::DeviceType::CPU: + return dispatch_cpu(input, dim, descending, out); + default: + throw std::runtime_error("Unsupported device type"); + } } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/src/torch_stablesort/torch_stablesort_cpu.h b/src/torch_stablesort/torch_stablesort_cpu.h new file mode 100644 index 0000000..b2bbce1 --- /dev/null +++ b/src/torch_stablesort/torch_stablesort_cpu.h @@ -0,0 +1,119 @@ +#pragma once + +#include + +#include +#include + +#include "dispatch.h" + +template +struct stable_sort_impl { + std::vector operator()( + torch::Tensor input, + int dim, + torch::optional> out + ) const { + + if (input.is_sparse()) + throw std::runtime_error("Sparse tensors are not supported"); + + if (input.device().type() != torch::DeviceType::CPU) + throw std::runtime_error("Only CPU tensors are supported"); + + if (out != torch::nullopt) + throw std::runtime_error("out argument is not supported"); + + auto in = (dim != -1) ? + torch::transpose(input, dim, -1) : + input; + + auto in_sizes = in.sizes(); + + // std::cout << "in_sizes: " << in_sizes << std::endl; + + in = in.view({ -1, in.size(-1) }).contiguous(); + + auto in_outer_stride = in.stride(-2); + auto in_inner_stride = in.stride(-1); + + auto pin = static_cast(in.data_ptr()); + + auto x = in.clone(); + + auto x_outer_stride = x.stride(-2); + auto x_inner_stride = x.stride(-1); + + auto n_cols = x.size(1); + auto n_rows = x.size(0); + auto px = static_cast(x.data_ptr()); + + auto y = torch::empty({ n_rows, n_cols }, + torch::TensorOptions().dtype(torch::kInt64)); + + auto y_outer_stride = y.stride(-2); + auto y_inner_stride = y.stride(-1); + + auto py = static_cast(y.data_ptr()); + + #pragma omp parallel for + for (decltype(n_rows) i = 0; i < n_rows; i++) { + std::vector indices(n_cols); + for (decltype(n_cols) k = 0; k < n_cols; k++) { + indices[k] = k; + } + + std::stable_sort(std::begin(indices), std::end(indices), + [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { + auto va = pin[i * in_outer_stride + a * in_inner_stride]; + auto vb = pin[i * in_outer_stride + b * in_inner_stride]; + if constexpr(descending) + return (vb < va); + else + return (va < vb); + }); + + for (decltype(n_cols) k = 0; k < n_cols; k++) { + py[i * y_outer_stride + k * y_inner_stride] = indices[k]; + px[i * x_outer_stride + k * x_inner_stride] = + pin[i * in_outer_stride + indices[k] * in_inner_stride]; + } + } + + // std::cout << "Here" << std::endl; + + x = x.view(in_sizes); + y = y.view(in_sizes); + + x = (dim == -1) ? + x : + torch::transpose(x, dim, -1).contiguous(); + + y = (dim == -1) ? + y : + torch::transpose(y, dim, -1).contiguous(); + + // std::cout << "Here 2" << std::endl; + + return { x, y }; + } +}; + +template +struct stable_sort_impl_desc: stable_sort_impl {}; + +template +struct stable_sort_impl_asc: stable_sort_impl {}; + +std::vector dispatch_cpu(torch::Tensor input, + int dim, + bool descending, + torch::optional> out) { + + if (descending) + return dispatch>( + input, dim, out); + else + return dispatch>( + input, dim, out); +} diff --git a/src/torch_stablesort/torch_stablesort_cuda.h b/src/torch_stablesort/torch_stablesort_cuda.h new file mode 100644 index 0000000..496429f --- /dev/null +++ b/src/torch_stablesort/torch_stablesort_cuda.h @@ -0,0 +1,109 @@ +#pragma once + +#include + +#include +#include +#include + +#include +#include + +#include "dispatch.h" + +template +struct stable_sort_impl_cuda { + std::vector operator()( + torch::Tensor input, + int dim, + torch::optional> out + ) const { + + if (input.is_sparse()) + throw std::runtime_error("Sparse tensors are not supported"); + + if (input.device().type() != torch::DeviceType::CUDA) + throw std::runtime_error("Only CUDA tensors are supported"); + + if (out != torch::nullopt) + throw std::runtime_error("out argument is not supported"); + + auto x = input.clone(); + + if (dim != -1) + x = torch::transpose(x, dim, -1); + + auto x_sizes = x.sizes(); + + x = x.view({ -1, x.size(-1) }).contiguous(); + + auto x_outer_stride = x.stride(-2); + auto x_inner_stride = x.stride(-1); + auto n_cols = x.size(1); + auto n_rows = x.size(0); + auto px = x.data_ptr(); + + assert(x_inner_stride == 1); + + auto y = torch::repeat_interleave( + torch::arange(0, n_cols, 1, torch::TensorOptions() + .dtype(torch::kInt32) + .device(x.device())), + torch::ones(n_rows, torch::TensorOptions() + .dtype(torch::kInt32) + .device(x.device())) + ); + + auto y_outer_stride = y.stride(-2); + auto y_inner_stride = y.stride(-1); + auto py = y.data_ptr(); + + assert(y_inner_stride == 1); + + for (decltype(n_rows) i = 0; i < n_rows; i++) { + auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride); + + auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride); + auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride + + n_cols * x_inner_stride); + + if constexpr(descending) + thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg, + thrust::greater()); + else + thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg); + } + + x = x.view(x_sizes); + y = y.view(x_sizes); + + x = (dim == -1) ? + x : + torch::transpose(x, dim, -1).contiguous(); + + y = (dim == -1) ? + y : + torch::transpose(y, dim, -1).contiguous(); + + return { x, y }; + } +}; + +template +struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda {}; + +template +struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda {}; + +std::vector dispatch_cuda(torch::Tensor input, + int dim, + bool descending, + torch::optional> out) { + + if (descending) + return dispatch>( + input, dim, out); + else + return dispatch>( + input, dim, out); +}