From 5cf1e3d1b594cdbeff88c9e97268449e05089cac Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 1 Sep 2020 18:32:58 +0200 Subject: [PATCH] torch_stablesort CUDA version works and is fast. --- src/torch_stablesort/torch_stablesort_cuda.cu | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/torch_stablesort/torch_stablesort_cuda.cu b/src/torch_stablesort/torch_stablesort_cuda.cu index e05aa2c..cf9b0d7 100644 --- a/src/torch_stablesort/torch_stablesort_cuda.cu +++ b/src/torch_stablesort/torch_stablesort_cuda.cu @@ -29,31 +29,27 @@ struct stable_sort_impl_cuda { if (dim != -1) values = torch::transpose(values, dim, -1); - auto values_sizes = values.sizes(); + auto orig_sizes = values.sizes(); values = values.view({ -1, values.size(-1) }).contiguous(); auto n_cols = values.size(1); auto n_rows = values.size(0); + auto n = n_rows * n_cols; assert(values.stride(-2) == n_cols); assert(values.stride(-1) == 1); auto values_ptr = values.data_ptr(); - auto indices = torch::repeat_interleave( - torch::arange(0, n_cols, 1, torch::TensorOptions() + auto indices = torch::arange(0, n, 1, torch::TensorOptions() .dtype(torch::kInt64) - .device(values.device())).view({ 1, -1 }), - n_rows, - 0 /* dim */ - ); + .device(values.device())).view({ n_rows, n_cols }); assert(indices.stride(-2) == n_cols); assert(indices.stride(-1) == 1); - auto indices_ptr = indices.data_ptr(); - auto n = n_rows * n_cols; + auto indices_ptr = indices.data_ptr(); auto ind_beg = thrust::device_pointer_cast(indices_ptr); auto val_beg = thrust::device_pointer_cast(values_ptr); @@ -74,15 +70,18 @@ struct stable_sort_impl_cuda { thrust::transform(thrust::device, ind_beg, ind_beg + n, n_cols_iter, - segments.begin(), thrust::modulus()); + segments.begin(), thrust::divides()); thrust::stable_sort_by_key(thrust::device, segments.begin(), segments.end(), ind_beg); + thrust::transform(thrust::device, ind_beg, ind_beg + n, + n_cols_iter, ind_beg, thrust::modulus()); + cudaDeviceSynchronize(); - values = values.view(values_sizes); - indices = indices.view(values_sizes); + values = values.view(orig_sizes); + indices = indices.view(orig_sizes); if (dim != -1) values = torch::transpose(values, dim, -1).contiguous();