#pragma once

#include <torch/extension.h>

#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/functional.h>
#include <thrust/execution_policy.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/iterator/constant_iterator.h>

#include "dispatch.h"
#include "torch_stablesort_cuda.h"

template<bool descending, typename T>
struct stable_sort_impl_cuda {
    std::vector<torch::Tensor> operator()(
        torch::Tensor input,
        int dim,
        torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
    ) const {

        if (input.is_sparse())
            throw std::runtime_error("Sparse tensors are not supported");

        if (input.device().type() != torch::DeviceType::CUDA)
            throw std::runtime_error("Only CUDA tensors are supported");

        if (out != torch::nullopt)
            throw std::runtime_error("out argument is not supported");

        auto values = input.clone();

        // Move the sort dimension to the innermost position and flatten to 2-D,
        // so that every row becomes an independent sort segment.
        if (dim != -1)
            values = torch::transpose(values, dim, -1);

        auto orig_sizes = values.sizes();

        values = values.view({ -1, values.size(-1) }).contiguous();

        auto n_cols = values.size(1);
        auto n_rows = values.size(0);
        auto n = n_rows * n_cols;

        assert(values.stride(-2) == n_cols);
        assert(values.stride(-1) == 1);

        auto values_ptr = values.data_ptr<T>();

        // Flat element indices 0..n-1; the row (segment) of element i is i / n_cols.
        auto indices = torch::arange(0, n, 1, torch::TensorOptions()
            .dtype(torch::kInt64)
            .device(values.device())).view({ n_rows, n_cols });

        assert(indices.stride(-2) == n_cols);
        assert(indices.stride(-1) == 1);

        auto indices_ptr = indices.data_ptr<int64_t>();

        auto ind_beg = thrust::device_pointer_cast(indices_ptr);
        auto val_beg = thrust::device_pointer_cast(values_ptr);

        // Pass 1: globally stable-sort all values, carrying the flat indices along.
        if (descending)
            thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n,
                ind_beg, thrust::greater<T>());
        else
            thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n,
                ind_beg);

        // Pass 2: stably re-sort by the original row id (flat index / n_cols);
        // stability keeps the per-row value order established in pass 1.
        thrust::device_vector<int64_t> segments(n);
        thrust::constant_iterator<int64_t> n_cols_iter(n_cols);

        thrust::transform(thrust::device, ind_beg, ind_beg + n,
            n_cols_iter, segments.begin(), thrust::divides<int64_t>());

        thrust::stable_sort_by_key(thrust::device, segments.begin(),
            segments.end(), val_beg);

        // The previous sort permuted the segment keys, so recompute them from
        // the (still untouched) flat indices before regrouping the index array.
        thrust::transform(thrust::device, ind_beg, ind_beg + n,
            n_cols_iter, segments.begin(), thrust::divides<int64_t>());

        thrust::stable_sort_by_key(thrust::device, segments.begin(),
            segments.end(), ind_beg);

        // Convert flat indices to within-row column indices.
        thrust::transform(thrust::device, ind_beg, ind_beg + n,
            n_cols_iter, ind_beg, thrust::modulus<int64_t>());

        cudaDeviceSynchronize();

        values = values.view(orig_sizes);
        indices = indices.view(orig_sizes);

        if (dim != -1)
            values = torch::transpose(values, dim, -1).contiguous();

        if (dim != -1)
            indices = torch::transpose(indices, dim, -1).contiguous();

        return { values, indices };
    }
};

template<typename T>
struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {};

template<typename T>
struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {};

std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input,
    int dim,
    bool descending,
    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {

    if (descending)
        return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>(
            input, dim, out);
    else
        return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>(
            input, dim, out);
}
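
// Usage sketch (not part of the original file): when this translation unit is
// built as a PyTorch C++/CUDA extension, the dispatcher above can be exposed
// to Python through a pybind11 binding. The opt-in macro guard and the exported
// name "stable_sort_cuda" below are illustrative assumptions, not something
// this repository defines; normally the binding lives in the extension's
// separate .cpp file rather than in this header.
#ifdef TORCH_STABLESORT_EXAMPLE_BINDING // hypothetical opt-in guard
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Returns { sorted_values, indices }, mirroring torch.sort() semantics.
    m.def("stable_sort_cuda", &dispatch_cuda,
          "stable sort along a dimension (CUDA tensors only)");
}
#endif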