|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
#pragma once

#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/functional.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>

#include "dispatch.h"
#include "torch_stablesort_cuda.h"
-
- template<bool descending, typename T>
- struct stable_sort_impl_cuda {
- std::vector<torch::Tensor> operator()(
- torch::Tensor input,
- int dim,
- torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
- ) const {
-
- if (input.is_sparse())
- throw std::runtime_error("Sparse tensors are not supported");
-
- if (input.device().type() != torch::DeviceType::CUDA)
- throw std::runtime_error("Only CUDA tensors are supported");
-
- if (out != torch::nullopt)
- throw std::runtime_error("out argument is not supported");
-
- auto values = input.clone();
-
- if (dim != -1)
- values = torch::transpose(values, dim, -1);
-
- auto values_sizes = values.sizes();
-
- values = values.view({ -1, values.size(-1) }).contiguous();
-
- auto n_cols = values.size(1);
- auto n_rows = values.size(0);
-
- assert(values.stride(-2) == n_cols);
- assert(values.stride(-1) == 1);
-
- auto values_ptr = values.data_ptr<T>();
-
- auto indices = torch::repeat_interleave(
- torch::arange(0, n_cols, 1, torch::TensorOptions()
- .dtype(torch::kInt64)
- .device(values.device())).view({ 1, -1 }),
- n_rows,
- 0 /* dim */
- );
-
- assert(indices.stride(-2) == n_cols);
- assert(indices.stride(-1) == 1);
- auto indices_ptr = indices.data_ptr<int64_t>();
-
- auto n = n_rows * n_cols;
-
- auto ind_beg = thrust::device_pointer_cast(indices_ptr);
- auto val_beg = thrust::device_pointer_cast(values_ptr);
-
- if (descending)
- thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg, thrust::greater<T>());
- else
- thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg);
-
- thrust::device_vector<int64_t> segments(n);
- thrust::constant_iterator<int64_t> n_cols_iter(n_cols);
- thrust::transform(thrust::device,
- ind_beg, ind_beg + n, n_cols_iter,
- segments.begin(), thrust::divides<int64_t>());
-
- thrust::stable_sort_by_key(thrust::device, segments.begin(),
- segments.end(), val_beg);
-
- thrust::transform(thrust::device,
- ind_beg, ind_beg + n, n_cols_iter,
- segments.begin(), thrust::modulus<int64_t>());
-
- thrust::stable_sort_by_key(thrust::device, segments.begin(),
- segments.end(), ind_beg);
-
- cudaDeviceSynchronize();
-
- values = values.view(values_sizes);
- indices = indices.view(values_sizes);
-
- if (dim != -1)
- values = torch::transpose(values, dim, -1).contiguous();
-
- if (dim != -1)
- indices = torch::transpose(indices, dim, -1).contiguous();
-
- return { values, indices };
- }
- };
-
- template <typename T>
- struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {};
-
- template <typename T>
- struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {};
-
- std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input,
- int dim,
- bool descending,
- torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) {
-
- if (descending)
- return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>(
- input, dim, out);
- else
- return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>(
- input, dim, out);
- }
|