#pragma once

#include <torch/extension.h>

#include <thrust/sort.h>
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/execution_policy.h>

#include <cassert>
#include <stdexcept>
#include <tuple>
#include <vector>

#include "dispatch.h"
#include "torch_stablesort_cuda.h"

template<bool descending, typename T>
struct stable_sort_impl_cuda {
    std::vector<torch::Tensor> operator()(
        torch::Tensor input,
        int dim,
        torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
    ) const {

        if (input.is_sparse())
            throw std::runtime_error("Sparse tensors are not supported");

        if (input.device().type() != torch::DeviceType::CUDA)
            throw std::runtime_error("Only CUDA tensors are supported");

        if (out != torch::nullopt)
            throw std::runtime_error("out argument is not supported");

        auto x = input.clone();

        // Move the sort dimension to the innermost position so that every
        // row of the flattened view below is one independent sequence.
        if (dim != -1)
            x = torch::transpose(x, dim, -1);

        // Copy the sizes: x.sizes() returns a reference into the tensor,
        // which would dangle once x is reassigned below.
        auto x_sizes = x.sizes().vec();

        // contiguous() must come before view(): after the transpose above,
        // x is generally not viewable.
        x = x.contiguous().view({ -1, x.size(-1) });

        auto x_outer_stride = x.stride(-2);
        auto x_inner_stride = x.stride(-1);
        auto n_cols = x.size(1);
        auto n_rows = x.size(0);
        auto px = x.data_ptr<T>();

        assert(x_inner_stride == 1);

        // Index matrix: each of the n_rows rows is [0, 1, ..., n_cols - 1].
        auto y = torch::repeat_interleave(
            torch::arange(0, n_cols, 1, torch::TensorOptions()
                .dtype(torch::kInt32)
                .device(x.device())).view({ 1, n_cols }),
            n_rows,
            /* dim = */ 0);

        auto y_outer_stride = y.stride(-2);
        auto y_inner_stride = y.stride(-1);
        auto py = y.data_ptr<int32_t>();

        assert(y_inner_stride == 1);

        // Sort each row in place with Thrust, permuting the index row
        // alongside the values. stable_sort_by_key preserves the relative
        // order of equal keys.
        for (decltype(n_rows) i = 0; i < n_rows; i++) {
            auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
            auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
            auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
                n_cols * x_inner_stride);

            if constexpr (descending)
                thrust::stable_sort_by_key(thrust::device, val_beg, val_end,
                    ind_beg, thrust::greater<T>());
            else
                thrust::stable_sort_by_key(thrust::device, val_beg, val_end,
                    ind_beg);
        }

        // Restore the original shape and put the sort dimension back.
        x = x.view(x_sizes);
        y = y.view(x_sizes);

        x = (dim == -1) ? x : torch::transpose(x, dim, -1).contiguous();
        y = (dim == -1) ? y : torch::transpose(y, dim, -1).contiguous();

        return { x, y };
    }
};

template<typename T>
struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {};

template<typename T>
struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {};

std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input,
    int dim,
    bool descending,
    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {

    // The template parameter order here assumes the dispatch<Impl, Ret>
    // helper declared in dispatch.h (see the sketch at the end of this file).
    if (descending)
        return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>(
            input, dim, out);
    else
        return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>(
            input, dim, out);
}
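
// ---------------------------------------------------------------------------
// For reference only: "dispatch.h" is not reproduced here. Below is a minimal
// sketch (kept commented out so it cannot clash with the real header) of the
// scalar-type dispatcher that dispatch_cuda() above assumes. The parameter
// order dispatch<Impl, Ret> and the set of handled scalar types are
// assumptions inferred from the call sites, not a verbatim copy.
//
// template<template<typename> class F, typename R, typename... Ts>
// R dispatch(torch::Tensor input, Ts&&... args) {
//     switch (input.scalar_type()) {
//     case torch::ScalarType::Double:
//         return F<double>()(input, std::forward<Ts>(args)...);
//     case torch::ScalarType::Float:
//         return F<float>()(input, std::forward<Ts>(args)...);
//     case torch::ScalarType::Long:
//         return F<int64_t>()(input, std::forward<Ts>(args)...);
//     case torch::ScalarType::Int:
//         return F<int32_t>()(input, std::forward<Ts>(args)...);
//     default:
//         throw std::runtime_error("Unsupported scalar type");
//     }
// }
//
// Given a pybind11 binding for dispatch_cuda (module and function names below
// are hypothetical, the binding is not part of this file), usage from Python
// would mirror torch.sort:
//
//   values, indices = torch_stablesort.stable_sort_cuda(x, -1, True, None)
// ---------------------------------------------------------------------------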