diff --git a/src/torch_stablesort/torch_stablesort_cuda.cu b/src/torch_stablesort/torch_stablesort_cuda.cu
index d7885d7..e05aa2c 100644
--- a/src/torch_stablesort/torch_stablesort_cuda.cu
+++ b/src/torch_stablesort/torch_stablesort_cuda.cu
@@ -24,75 +24,73 @@ struct stable_sort_impl_cuda {
     if (out != torch::nullopt)
       throw std::runtime_error("out argument is not supported");
 
-    auto x = input.clone();
+    auto values = input.clone();
 
     if (dim != -1)
-      x = torch::transpose(x, dim, -1);
+      values = torch::transpose(values, dim, -1);
 
-    auto x_sizes = x.sizes();
+    auto values_sizes = values.sizes();
 
-    x = x.view({ -1, x.size(-1) }).contiguous();
+    values = values.view({ -1, values.size(-1) }).contiguous();
 
-    auto x_outer_stride = x.stride(-2);
-    auto x_inner_stride = x.stride(-1);
-    auto n_cols = x.size(1);
-    auto n_rows = x.size(0);
-    auto px = x.data_ptr<T>();
+    auto n_cols = values.size(1);
+    auto n_rows = values.size(0);
 
-    assert(x_inner_stride == 1);
+    assert(values.stride(-2) == n_cols);
+    assert(values.stride(-1) == 1);
 
-    auto y = torch::repeat_interleave(
+    auto values_ptr = values.data_ptr<T>();
+
+    auto indices = torch::repeat_interleave(
       torch::arange(0, n_cols, 1, torch::TensorOptions()
         .dtype(torch::kInt64)
-        .device(x.device())).view({ 1, -1 }),
+        .device(values.device())).view({ 1, -1 }),
       n_rows,
       0 /* dim */
     );
 
-    auto y_outer_stride = y.stride(-2);
-    auto y_inner_stride = y.stride(-1);
-    auto py = y.data_ptr<int64_t>();
-
-    assert(y_inner_stride == 1);
-
-    #define NUM_STREAMS 16
-    cudaStream_t streams[NUM_STREAMS];
-    for(int i = 0; i < NUM_STREAMS; i++)
-      assert(cudaStreamCreate(&streams[i]) == cudaSuccess);
-
-    thrust::host_vector<int64_t> row_indices(n_rows);
-    thrust::sequence(row_indices.begin(), row_indices.end());
-    thrust::for_each(thrust::host, row_indices.begin(), row_indices.end(),
-      [&streams, py, y_outer_stride, px, x_outer_stride, x_inner_stride, n_cols](int64_t i) {
-        auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
-
-        auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
-        auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
-          n_cols * x_inner_stride);
-
-        if (descending)
-          thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg,
-            thrust::greater<T>());
-        else
-          thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg);
-      });
-    cudaDeviceSynchronize();
+    assert(indices.stride(-2) == n_cols);
+    assert(indices.stride(-1) == 1);
+    auto indices_ptr = indices.data_ptr<int64_t>();
+
+    auto n = n_rows * n_cols;
+
+    auto ind_beg = thrust::device_pointer_cast(indices_ptr);
+    auto val_beg = thrust::device_pointer_cast(values_ptr);
+
+    if (descending)
+      thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg, thrust::greater<T>());
+    else
+      thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg);
+
+    thrust::device_vector<int64_t> segments(n);
+    thrust::constant_iterator<int64_t> n_cols_iter(n_cols);
+    thrust::transform(thrust::device,
+      ind_beg, ind_beg + n, n_cols_iter,
+      segments.begin(), thrust::divides<int64_t>());
 
-    for(int i = 0; i < NUM_STREAMS; i++)
-      assert(cudaStreamDestroy(streams[i]) == cudaSuccess);
+    thrust::stable_sort_by_key(thrust::device, segments.begin(),
+      segments.end(), val_beg);
 
-    x = x.view(x_sizes);
-    y = y.view(x_sizes);
+    thrust::transform(thrust::device,
+      ind_beg, ind_beg + n, n_cols_iter,
+      segments.begin(), thrust::modulus<int64_t>());
 
-    x = (dim == -1) ?
-      x :
-      torch::transpose(x, dim, -1).contiguous();
+    thrust::stable_sort_by_key(thrust::device, segments.begin(),
+      segments.end(), ind_beg);
 
-    y = (dim == -1) ?
-      y :
-      torch::transpose(y, dim, -1).contiguous();
+    cudaDeviceSynchronize();
+
+    values = values.view(values_sizes);
+    indices = indices.view(values_sizes);
+
+    if (dim != -1)
+      values = torch::transpose(values, dim, -1).contiguous();
+
+    if (dim != -1)
+      indices = torch::transpose(indices, dim, -1).contiguous();
 
-    return { x, y };
+    return { values, indices };
   }
 };
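
For reference, below is a minimal standalone sketch of the batched-argsort pattern this change moves to: one global `thrust::stable_sort_by_key` over all values, followed by a second stable sort keyed on each element's row, which regroups the rows while keeping every row's values in sorted order. It is not code from the extension: the flat-index payload, the zip iterator that carries values and indices together through the second sort, the `float` element type, and ascending order are all assumptions made for the sketch.

```cpp
// Batched (row-wise) argsort via two back-to-back stable sorts -- a sketch, not the
// extension's code. Assumes float values, ascending order, and a flat-index payload.
#include <cstdint>
#include <cstdio>

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>

int main() {
  const int64_t n_rows = 2, n_cols = 3, n = n_rows * n_cols;

  // Row-major input: row 0 = {5, 2, 2}, row 1 = {1, 4, 0}.
  const float h_values[] = { 5, 2, 2, 1, 4, 0 };
  thrust::device_vector<float> values(h_values, h_values + n);

  // Payload: flat element indices 0..n-1, so each element's row stays recoverable.
  thrust::device_vector<int64_t> indices(n);
  thrust::sequence(indices.begin(), indices.end());

  // 1) One global stable sort mixes the rows but records where every value came from.
  thrust::stable_sort_by_key(thrust::device, values.begin(), values.end(),
                             indices.begin());

  // 2) Row id of each (now value-sorted) element: flat index / n_cols.
  thrust::device_vector<int64_t> rows(n);
  thrust::transform(thrust::device, indices.begin(), indices.end(),
                    thrust::constant_iterator<int64_t>(n_cols),
                    rows.begin(), thrust::divides<int64_t>());

  // 3) Stable sort by row id. Stability preserves the ascending-value order from
  //    step 1 inside each row; the zip iterator keeps values and indices aligned.
  thrust::stable_sort_by_key(thrust::device, rows.begin(), rows.end(),
                             thrust::make_zip_iterator(
                                 thrust::make_tuple(values.begin(), indices.begin())));

  // 4) Reduce flat indices to column indices: that is the argsort along the last dim.
  thrust::transform(thrust::device, indices.begin(), indices.end(),
                    thrust::constant_iterator<int64_t>(n_cols),
                    indices.begin(), thrust::modulus<int64_t>());

  // Expected: values = 2 2 5 | 0 1 4   indices = 1 2 0 | 2 0 1
  for (int64_t i = 0; i < n; ++i) {
    const float v = values[i];        // per-element device->host copy, fine for a demo
    const long long c = indices[i];
    std::printf("%g/%lld ", v, c);
  }
  std::printf("\n");
  return 0;
}
```

The whole trick rests on `thrust::stable_sort_by_key` being stable: a non-stable second pass would scramble the per-row ordering established by the first pass, and without stability in the first pass ties would lose the original left-to-right order that makes the result a stable sort.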