| @@ -1,3 +1,5 @@ | |||||
| #pragma once | |||||
| #include <utility> | #include <utility> | ||||
| template<template<typename T> class F, typename R, typename... Ts> | template<template<typename T> class F, typename R, typename... Ts> | ||||
| @@ -5,5 +5,6 @@ setup(name='torch_stablesort', | |||||
| py_modules=['torch_stablesort'], | py_modules=['torch_stablesort'], | ||||
| ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', | ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', | ||||
| ['torch_stablesort.cpp'], | ['torch_stablesort.cpp'], | ||||
| extra_compile_args=['-fopenmp', '-ggdb', '-std=c++1z'])], | |||||
| cmdclass={'build_ext': cpp_extension.BuildExtension}) | |||||
| extra_compile_args=['-I/pstore/home/adaszews/scratch/thrust', | |||||
| '-fopenmp', '-ggdb', '-std=c++1z'])], | |||||
| cmdclass={'build_ext': cpp_extension.BuildExtension}) | |||||
| @@ -4,105 +4,8 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "dispatch.h" | |||||
| template<bool descending, typename T> | |||||
| struct stable_sort_impl { | |||||
| std::vector<torch::Tensor> operator()( | |||||
| torch::Tensor input, | |||||
| int dim, | |||||
| torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out | |||||
| ) const { | |||||
| if (input.is_sparse()) | |||||
| throw std::runtime_error("Sparse tensors are not supported"); | |||||
| if (input.device().type() != torch::DeviceType::CPU) | |||||
| throw std::runtime_error("Only CPU tensors are supported"); | |||||
| if (out != torch::nullopt) | |||||
| throw std::runtime_error("out argument is not supported"); | |||||
| auto in = (dim != -1) ? | |||||
| torch::transpose(input, dim, -1) : | |||||
| input; | |||||
| auto in_sizes = in.sizes(); | |||||
| // std::cout << "in_sizes: " << in_sizes << std::endl; | |||||
| in = in.view({ -1, in.size(-1) }).contiguous(); | |||||
| auto in_outer_stride = in.stride(-2); | |||||
| auto in_inner_stride = in.stride(-1); | |||||
| auto pin = static_cast<T*>(in.data_ptr()); | |||||
| auto x = in.clone(); | |||||
| auto x_outer_stride = x.stride(-2); | |||||
| auto x_inner_stride = x.stride(-1); | |||||
| auto n_cols = x.size(1); | |||||
| auto n_rows = x.size(0); | |||||
| auto px = static_cast<T*>(x.data_ptr()); | |||||
| auto y = torch::empty({ n_rows, n_cols }, | |||||
| torch::TensorOptions().dtype(torch::kInt64)); | |||||
| auto y_outer_stride = y.stride(-2); | |||||
| auto y_inner_stride = y.stride(-1); | |||||
| auto py = static_cast<int64_t*>(y.data_ptr()); | |||||
| #pragma omp parallel for | |||||
| for (decltype(n_rows) i = 0; i < n_rows; i++) { | |||||
| std::vector<int64_t> indices(n_cols); | |||||
| for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||||
| indices[k] = k; | |||||
| } | |||||
| std::stable_sort(std::begin(indices), std::end(indices), | |||||
| [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { | |||||
| auto va = pin[i * in_outer_stride + a * in_inner_stride]; | |||||
| auto vb = pin[i * in_outer_stride + b * in_inner_stride]; | |||||
| if constexpr(descending) | |||||
| return (vb < va); | |||||
| else | |||||
| return (va < vb); | |||||
| }); | |||||
| for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||||
| py[i * y_outer_stride + k * y_inner_stride] = indices[k]; | |||||
| px[i * x_outer_stride + k * x_inner_stride] = | |||||
| pin[i * in_outer_stride + indices[k] * in_inner_stride]; | |||||
| } | |||||
| } | |||||
| // std::cout << "Here" << std::endl; | |||||
| x = x.view(in_sizes); | |||||
| y = y.view(in_sizes); | |||||
| x = (dim == -1) ? | |||||
| x : | |||||
| torch::transpose(x, dim, -1).contiguous(); | |||||
| y = (dim == -1) ? | |||||
| y : | |||||
| torch::transpose(y, dim, -1).contiguous(); | |||||
| // std::cout << "Here 2" << std::endl; | |||||
| return { x, y }; | |||||
| } | |||||
| }; | |||||
| template <typename T> | |||||
| struct stable_sort_impl_desc: stable_sort_impl<true, T> {}; | |||||
| template <typename T> | |||||
| struct stable_sort_impl_asc: stable_sort_impl<false, T> {}; | |||||
| #include "torch_stablesort_cuda.h" | |||||
| #include "torch_stablesort_cpu.h" | |||||
| std::vector<torch::Tensor> stable_sort( | std::vector<torch::Tensor> stable_sort( | ||||
| torch::Tensor input, | torch::Tensor input, | ||||
| @@ -110,12 +13,14 @@ std::vector<torch::Tensor> stable_sort( | |||||
| bool descending = false, | bool descending = false, | ||||
| torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) { | torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) { | ||||
| if (descending) | |||||
| return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>( | |||||
| input, dim, out); | |||||
| else | |||||
| return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>( | |||||
| input, dim, out); | |||||
| switch (input.device().type()) { | |||||
| case torch::DeviceType::CUDA: | |||||
| return dispatch_cuda(input, dim, descending, out); | |||||
| case torch::DeviceType::CPU: | |||||
| return dispatch_cpu(input, dim, descending, out); | |||||
| default: | |||||
| throw std::runtime_error("Unsupported device type"); | |||||
| } | |||||
| } | } | ||||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||||
| @@ -0,0 +1,119 @@ | |||||
| #pragma once | |||||
| #include <torch/extension.h> | |||||
| #include <vector> | |||||
| #include <tuple> | |||||
| #include "dispatch.h" | |||||
| template<bool descending, typename T> | |||||
| struct stable_sort_impl { | |||||
| std::vector<torch::Tensor> operator()( | |||||
| torch::Tensor input, | |||||
| int dim, | |||||
| torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out | |||||
| ) const { | |||||
| if (input.is_sparse()) | |||||
| throw std::runtime_error("Sparse tensors are not supported"); | |||||
| if (input.device().type() != torch::DeviceType::CPU) | |||||
| throw std::runtime_error("Only CPU tensors are supported"); | |||||
| if (out != torch::nullopt) | |||||
| throw std::runtime_error("out argument is not supported"); | |||||
| auto in = (dim != -1) ? | |||||
| torch::transpose(input, dim, -1) : | |||||
| input; | |||||
| auto in_sizes = in.sizes(); | |||||
| // std::cout << "in_sizes: " << in_sizes << std::endl; | |||||
| in = in.view({ -1, in.size(-1) }).contiguous(); | |||||
| auto in_outer_stride = in.stride(-2); | |||||
| auto in_inner_stride = in.stride(-1); | |||||
| auto pin = static_cast<T*>(in.data_ptr()); | |||||
| auto x = in.clone(); | |||||
| auto x_outer_stride = x.stride(-2); | |||||
| auto x_inner_stride = x.stride(-1); | |||||
| auto n_cols = x.size(1); | |||||
| auto n_rows = x.size(0); | |||||
| auto px = static_cast<T*>(x.data_ptr()); | |||||
| auto y = torch::empty({ n_rows, n_cols }, | |||||
| torch::TensorOptions().dtype(torch::kInt64)); | |||||
| auto y_outer_stride = y.stride(-2); | |||||
| auto y_inner_stride = y.stride(-1); | |||||
| auto py = static_cast<int64_t*>(y.data_ptr()); | |||||
| #pragma omp parallel for | |||||
| for (decltype(n_rows) i = 0; i < n_rows; i++) { | |||||
| std::vector<int64_t> indices(n_cols); | |||||
| for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||||
| indices[k] = k; | |||||
| } | |||||
| std::stable_sort(std::begin(indices), std::end(indices), | |||||
| [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { | |||||
| auto va = pin[i * in_outer_stride + a * in_inner_stride]; | |||||
| auto vb = pin[i * in_outer_stride + b * in_inner_stride]; | |||||
| if constexpr(descending) | |||||
| return (vb < va); | |||||
| else | |||||
| return (va < vb); | |||||
| }); | |||||
| for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||||
| py[i * y_outer_stride + k * y_inner_stride] = indices[k]; | |||||
| px[i * x_outer_stride + k * x_inner_stride] = | |||||
| pin[i * in_outer_stride + indices[k] * in_inner_stride]; | |||||
| } | |||||
| } | |||||
| // std::cout << "Here" << std::endl; | |||||
| x = x.view(in_sizes); | |||||
| y = y.view(in_sizes); | |||||
| x = (dim == -1) ? | |||||
| x : | |||||
| torch::transpose(x, dim, -1).contiguous(); | |||||
| y = (dim == -1) ? | |||||
| y : | |||||
| torch::transpose(y, dim, -1).contiguous(); | |||||
| // std::cout << "Here 2" << std::endl; | |||||
| return { x, y }; | |||||
| } | |||||
| }; | |||||
| template <typename T> | |||||
| struct stable_sort_impl_desc: stable_sort_impl<true, T> {}; | |||||
| template <typename T> | |||||
| struct stable_sort_impl_asc: stable_sort_impl<false, T> {}; | |||||
| std::vector<torch::Tensor> dispatch_cpu(torch::Tensor input, | |||||
| int dim, | |||||
| bool descending, | |||||
| torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) { | |||||
| if (descending) | |||||
| return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>( | |||||
| input, dim, out); | |||||
| else | |||||
| return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>( | |||||
| input, dim, out); | |||||
| } | |||||
| @@ -0,0 +1,109 @@ | |||||
| #pragma once | |||||
| #include <torch/extension.h> | |||||
| #include <thrust/sort.h> | |||||
| #include <thrust/device_ptr.h> | |||||
| #include <thrust/execution_policy.h> | |||||
| #include <vector> | |||||
| #include <tuple> | |||||
| #include "dispatch.h" | |||||
| template<bool descending, typename T> | |||||
| struct stable_sort_impl_cuda { | |||||
| std::vector<torch::Tensor> operator()( | |||||
| torch::Tensor input, | |||||
| int dim, | |||||
| torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out | |||||
| ) const { | |||||
| if (input.is_sparse()) | |||||
| throw std::runtime_error("Sparse tensors are not supported"); | |||||
| if (input.device().type() != torch::DeviceType::CUDA) | |||||
| throw std::runtime_error("Only CUDA tensors are supported"); | |||||
| if (out != torch::nullopt) | |||||
| throw std::runtime_error("out argument is not supported"); | |||||
| auto x = input.clone(); | |||||
| if (dim != -1) | |||||
| x = torch::transpose(x, dim, -1); | |||||
| auto x_sizes = x.sizes(); | |||||
| x = x.view({ -1, x.size(-1) }).contiguous(); | |||||
| auto x_outer_stride = x.stride(-2); | |||||
| auto x_inner_stride = x.stride(-1); | |||||
| auto n_cols = x.size(1); | |||||
| auto n_rows = x.size(0); | |||||
| auto px = x.data_ptr<T>(); | |||||
| assert(x_inner_stride == 1); | |||||
| auto y = torch::repeat_interleave( | |||||
| torch::arange(0, n_cols, 1, torch::TensorOptions() | |||||
| .dtype(torch::kInt32) | |||||
| .device(x.device())), | |||||
| torch::ones(n_rows, torch::TensorOptions() | |||||
| .dtype(torch::kInt32) | |||||
| .device(x.device())) | |||||
| ); | |||||
| auto y_outer_stride = y.stride(-2); | |||||
| auto y_inner_stride = y.stride(-1); | |||||
| auto py = y.data_ptr<int32_t>(); | |||||
| assert(y_inner_stride == 1); | |||||
| for (decltype(n_rows) i = 0; i < n_rows; i++) { | |||||
| auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride); | |||||
| auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride); | |||||
| auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride + | |||||
| n_cols * x_inner_stride); | |||||
| if constexpr(descending) | |||||
| thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg, | |||||
| thrust::greater<T>()); | |||||
| else | |||||
| thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg); | |||||
| } | |||||
| x = x.view(x_sizes); | |||||
| y = y.view(x_sizes); | |||||
| x = (dim == -1) ? | |||||
| x : | |||||
| torch::transpose(x, dim, -1).contiguous(); | |||||
| y = (dim == -1) ? | |||||
| y : | |||||
| torch::transpose(y, dim, -1).contiguous(); | |||||
| return { x, y }; | |||||
| } | |||||
| }; | |||||
| template <typename T> | |||||
| struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {}; | |||||
| template <typename T> | |||||
| struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {}; | |||||
| std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input, | |||||
| int dim, | |||||
| bool descending, | |||||
| torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) { | |||||
| if (descending) | |||||
| return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>( | |||||
| input, dim, out); | |||||
| else | |||||
| return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>( | |||||
| input, dim, out); | |||||
| } | |||||