@@ -1,3 +1,5 @@
+#pragma once
+#include <utility>
 template<template<typename T> class F, typename R, typename... Ts>
@@ -5,5 +5,6 @@ setup(name='torch_stablesort',
       py_modules=['torch_stablesort'],
       ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp',
           ['torch_stablesort.cpp'],
-          extra_compile_args=['-fopenmp', '-ggdb', '-std=c++1z'])],
-      cmdclass={'build_ext': cpp_extension.BuildExtension})
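+          # -I points at a local Thrust checkout so the host compiler can
+          # find the <thrust/...> headers used by torch_stablesort_cuda.h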
+          extra_compile_args=['-I/pstore/home/adaszews/scratch/thrust',
+              '-fopenmp', '-ggdb', '-std=c++1z'])],
+      cmdclass={'build_ext': cpp_extension.BuildExtension})
@@ -4,105 +4,8 @@
 #include <vector>
 #include <algorithm>
 #include "dispatch.h"
-template<bool descending, typename T>
-struct stable_sort_impl {
-    std::vector<torch::Tensor> operator()(
-        torch::Tensor input,
-        int dim,
-        torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
-    ) const {
-        if (input.is_sparse())
-            throw std::runtime_error("Sparse tensors are not supported");
-        if (input.device().type() != torch::DeviceType::CPU)
-            throw std::runtime_error("Only CPU tensors are supported");
-        if (out != torch::nullopt)
-            throw std::runtime_error("out argument is not supported");
-        auto in = (dim != -1) ?
-            torch::transpose(input, dim, -1) :
-            input;
-        auto in_sizes = in.sizes();
-        // std::cout << "in_sizes: " << in_sizes << std::endl;
-        in = in.view({ -1, in.size(-1) }).contiguous();
-        auto in_outer_stride = in.stride(-2);
-        auto in_inner_stride = in.stride(-1);
-        auto pin = static_cast<T*>(in.data_ptr());
-        auto x = in.clone();
-        auto x_outer_stride = x.stride(-2);
-        auto x_inner_stride = x.stride(-1);
-        auto n_cols = x.size(1);
-        auto n_rows = x.size(0);
-        auto px = static_cast<T*>(x.data_ptr());
-        auto y = torch::empty({ n_rows, n_cols },
-            torch::TensorOptions().dtype(torch::kInt64));
-        auto y_outer_stride = y.stride(-2);
-        auto y_inner_stride = y.stride(-1);
-        auto py = static_cast<int64_t*>(y.data_ptr());
-        #pragma omp parallel for
-        for (decltype(n_rows) i = 0; i < n_rows; i++) {
-            std::vector<int64_t> indices(n_cols);
-            for (decltype(n_cols) k = 0; k < n_cols; k++) {
-                indices[k] = k;
-            }
-            std::stable_sort(std::begin(indices), std::end(indices),
-                [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
-                    auto va = pin[i * in_outer_stride + a * in_inner_stride];
-                    auto vb = pin[i * in_outer_stride + b * in_inner_stride];
-                    if constexpr(descending)
-                        return (vb < va);
-                    else
-                        return (va < vb);
-                });
-            for (decltype(n_cols) k = 0; k < n_cols; k++) {
-                py[i * y_outer_stride + k * y_inner_stride] = indices[k];
-                px[i * x_outer_stride + k * x_inner_stride] =
-                    pin[i * in_outer_stride + indices[k] * in_inner_stride];
-            }
-        }
-        // std::cout << "Here" << std::endl;
-        x = x.view(in_sizes);
-        y = y.view(in_sizes);
-        x = (dim == -1) ?
-            x :
-            torch::transpose(x, dim, -1).contiguous();
-        y = (dim == -1) ?
-            y :
-            torch::transpose(y, dim, -1).contiguous();
-        // std::cout << "Here 2" << std::endl;
-        return { x, y };
-    }
-};
-template <typename T>
-struct stable_sort_impl_desc: stable_sort_impl<true, T> {};
-template <typename T>
-struct stable_sort_impl_asc: stable_sort_impl<false, T> {};
| #include "torch_stablesort_cuda.h" | |||
| #include "torch_stablesort_cpu.h" | |||
| std::vector<torch::Tensor> stable_sort( | |||
| torch::Tensor input, | |||
@@ -110,12 +13,14 @@ std::vector<torch::Tensor> stable_sort(
     bool descending = false,
     torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) {
-    if (descending)
-        return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>(
-            input, dim, out);
-    else
-        return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>(
-            input, dim, out);
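+    // Route to the device-specific implementation; each header supplies its
+    // own per-dtype dispatch.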
+    switch (input.device().type()) {
+    case torch::DeviceType::CUDA:
+        return dispatch_cuda(input, dim, descending, out);
+    case torch::DeviceType::CPU:
+        return dispatch_cpu(input, dim, descending, out);
+    default:
+        throw std::runtime_error("Unsupported device type");
+    }
 }
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -0,0 +1,119 @@
+#pragma once
+#include <torch/extension.h>
+#include <vector>
+#include <tuple>
+#include <algorithm> // std::stable_sort
+#include "dispatch.h"
+template<bool descending, typename T>
+struct stable_sort_impl {
+    std::vector<torch::Tensor> operator()(
+        torch::Tensor input,
+        int dim,
+        torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
+    ) const {
+        if (input.is_sparse())
+            throw std::runtime_error("Sparse tensors are not supported");
+        if (input.device().type() != torch::DeviceType::CPU)
+            throw std::runtime_error("Only CPU tensors are supported");
+        if (out != torch::nullopt)
+            throw std::runtime_error("out argument is not supported");
+        auto in = (dim != -1) ?
+            torch::transpose(input, dim, -1) :
+            input;
+        auto in_sizes = in.sizes().vec();
+        in = in.contiguous().view({ -1, in.size(-1) });
+        auto in_outer_stride = in.stride(-2);
+        auto in_inner_stride = in.stride(-1);
+        auto pin = static_cast<T*>(in.data_ptr());
+        auto x = in.clone();
+        auto x_outer_stride = x.stride(-2);
+        auto x_inner_stride = x.stride(-1);
+        auto n_cols = x.size(1);
+        auto n_rows = x.size(0);
+        auto px = static_cast<T*>(x.data_ptr());
+        auto y = torch::empty({ n_rows, n_cols },
+            torch::TensorOptions().dtype(torch::kInt64));
+        auto y_outer_stride = y.stride(-2);
+        auto y_inner_stride = y.stride(-1);
+        auto py = static_cast<int64_t*>(y.data_ptr());
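+        // Rows are independent, so they are sorted in parallel. Each thread
+        // argsorts an index vector rather than moving the data itself;
+        // std::stable_sort keeps equal elements in their original order.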
+        #pragma omp parallel for
+        for (decltype(n_rows) i = 0; i < n_rows; i++) {
+            std::vector<int64_t> indices(n_cols);
+            for (decltype(n_cols) k = 0; k < n_cols; k++) {
+                indices[k] = k;
+            }
+            std::stable_sort(std::begin(indices), std::end(indices),
+                [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
+                    auto va = pin[i * in_outer_stride + a * in_inner_stride];
+                    auto vb = pin[i * in_outer_stride + b * in_inner_stride];
+                    if constexpr(descending)
+                        return (vb < va);
+                    else
+                        return (va < vb);
+                });
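+            // Write the permutation into y and gather the values into x in
+            // sorted order.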
+            for (decltype(n_cols) k = 0; k < n_cols; k++) {
+                py[i * y_outer_stride + k * y_inner_stride] = indices[k];
+                px[i * x_outer_stride + k * x_inner_stride] =
+                    pin[i * in_outer_stride + indices[k] * in_inner_stride];
+            }
+        }
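+        // Undo the flattening and the transpose (if any) on both outputs.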
| // std::cout << "Here" << std::endl; | |||
| x = x.view(in_sizes); | |||
| y = y.view(in_sizes); | |||
| x = (dim == -1) ? | |||
| x : | |||
| torch::transpose(x, dim, -1).contiguous(); | |||
| y = (dim == -1) ? | |||
| y : | |||
| torch::transpose(y, dim, -1).contiguous(); | |||
| // std::cout << "Here 2" << std::endl; | |||
| return { x, y }; | |||
| } | |||
| }; | |||
+template <typename T>
+struct stable_sort_impl_desc: stable_sort_impl<true, T> {};
+template <typename T>
+struct stable_sort_impl_asc: stable_sort_impl<false, T> {};
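+// dispatch<F, R, ...> (declared in dispatch.h, body not shown in this diff)
+// presumably switches on the tensor's scalar type and invokes F<T> for the
+// matching element type T; the ascending/descending choice is baked in at
+// compile time via the two aliases above.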
+std::vector<torch::Tensor> dispatch_cpu(torch::Tensor input,
+    int dim,
+    bool descending,
+    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
+    if (descending)
+        return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>(
+            input, dim, out);
+    else
+        return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>(
+            input, dim, out);
+}
@@ -0,0 +1,109 @@
+#pragma once
+#include <torch/extension.h>
+#include <thrust/sort.h>
+#include <thrust/device_ptr.h>
+#include <thrust/execution_policy.h>
+#include <thrust/functional.h> // thrust::greater
+#include <vector>
+#include <tuple>
+#include <cassert> // assert
+#include "dispatch.h"
+template<bool descending, typename T>
+struct stable_sort_impl_cuda {
+    std::vector<torch::Tensor> operator()(
+        torch::Tensor input,
+        int dim,
+        torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
+    ) const {
+        if (input.is_sparse())
+            throw std::runtime_error("Sparse tensors are not supported");
+        if (input.device().type() != torch::DeviceType::CUDA)
+            throw std::runtime_error("Only CUDA tensors are supported");
+        if (out != torch::nullopt)
+            throw std::runtime_error("out argument is not supported");
+        auto x = input.clone();
+        if (dim != -1)
+            x = torch::transpose(x, dim, -1);
+        auto x_sizes = x.sizes().vec();
+        x = x.contiguous().view({ -1, x.size(-1) });
+        auto x_outer_stride = x.stride(-2);
+        auto x_inner_stride = x.stride(-1);
+        auto n_cols = x.size(1);
+        auto n_rows = x.size(0);
+        auto px = x.data_ptr<T>();
+        assert(x_inner_stride == 1);
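+        // Build the index matrix: every row is 0..n_cols-1. Note the indices
+        // are int32 here, unlike the int64 indices of the CPU path.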
+        auto y = torch::arange(0, n_cols, 1, torch::TensorOptions()
+                .dtype(torch::kInt32)
+                .device(x.device()))
+            .repeat({ n_rows, 1 });
+        auto y_outer_stride = y.stride(-2);
+        auto y_inner_stride = y.stride(-1);
+        auto py = y.data_ptr<int32_t>();
+        assert(y_inner_stride == 1);
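+        // Sort each row on the GPU: the values act as keys and are sorted in
+        // place, while the same permutation is applied to the index row.
+        // thrust::stable_sort_by_key keeps ties in their original order.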
+        for (decltype(n_rows) i = 0; i < n_rows; i++) {
+            auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
+            auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
+            auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
+                n_cols * x_inner_stride);
+            if constexpr(descending)
+                thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg,
+                    thrust::greater<T>());
+            else
+                thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg);
+        }
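+        // Undo the flattening and the transpose (if any) on both outputs.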
+        x = x.view(x_sizes);
+        y = y.view(x_sizes);
+        x = (dim == -1) ?
+            x :
+            torch::transpose(x, dim, -1).contiguous();
+        y = (dim == -1) ?
+            y :
+            torch::transpose(y, dim, -1).contiguous();
+        return { x, y };
+    }
+};
+template <typename T>
+struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {};
+template <typename T>
+struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {};
+std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input,
+    int dim,
+    bool descending,
+    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
+    if (descending)
+        return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>(
+            input, dim, out);
+    else
+        return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>(
+            input, dim, out);
+}
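# Hypothetical usage once the extension is built (this assumes the elided
# PYBIND11_MODULE body registers the function as "stable_sort" taking
# positional arguments; the names below are illustrative, not confirmed by
# this diff):
#
#   import torch
#   import torch_stablesort_cpp
#
#   x = torch.randint(0, 5, (3, 8))
#   values, indices = torch_stablesort_cpp.stable_sort(x, -1, False)
#   # equal elements of x keep their original relative order in values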