diff --git a/src/torch_stablesort/setup.py b/src/torch_stablesort/setup.py index ab9412c..ae03d78 100644 --- a/src/torch_stablesort/setup.py +++ b/src/torch_stablesort/setup.py @@ -8,6 +8,7 @@ setup(name='torch_stablesort', extra_compile_args={ 'cxx': ['-fopenmp', '-ggdb', '-std=c++1z'], 'nvcc': [ '-I/pstore/home/adaszews/scratch/thrust', - '-ccbin', '/pstore/data/data_science/app/modules/anaconda3-2020.07/bin/x86_64-conda_cos6-linux-gnu-gcc', '-std=c++14'] + '-ccbin', '/pstore/data/data_science/app/modules/anaconda3-2020.07/bin/x86_64-conda_cos6-linux-gnu-gcc', + '-std=c++14', '--expt-extended-lambda', '-O99'] } ) ], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/src/torch_stablesort/torch_stablesort_cuda.cu b/src/torch_stablesort/torch_stablesort_cuda.cu index bafa9ee..d7885d7 100644 --- a/src/torch_stablesort/torch_stablesort_cuda.cu +++ b/src/torch_stablesort/torch_stablesort_cuda.cu @@ -43,32 +43,43 @@ struct stable_sort_impl_cuda { auto y = torch::repeat_interleave( torch::arange(0, n_cols, 1, torch::TensorOptions() - .dtype(torch::kInt32) - .device(x.device())), - torch::ones(n_rows, torch::TensorOptions() - .dtype(torch::kInt32) - .device(x.device())) + .dtype(torch::kInt64) + .device(x.device())).view({ 1, -1 }), + n_rows, + 0 /* dim */ ); auto y_outer_stride = y.stride(-2); auto y_inner_stride = y.stride(-1); - auto py = y.data_ptr(); + auto py = y.data_ptr(); assert(y_inner_stride == 1); - for (decltype(n_rows) i = 0; i < n_rows; i++) { - auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride); - - auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride); - auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride + - n_cols * x_inner_stride); - - if constexpr(descending) - thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg, - thrust::greater()); - else - thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg); - } + #define NUM_STREAMS 16 + cudaStream_t streams[NUM_STREAMS]; + for(int i = 0; i < NUM_STREAMS; i++) + assert(cudaStreamCreate(&streams[i]) == cudaSuccess); + + thrust::host_vector row_indices(n_rows); + thrust::sequence(row_indices.begin(), row_indices.end()); + thrust::for_each(thrust::host, row_indices.begin(), row_indices.end(), + [&streams, py, y_outer_stride, px, x_outer_stride, x_inner_stride, n_cols](int64_t i) { + auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride); + + auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride); + auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride + + n_cols * x_inner_stride); + + if (descending) + thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg, + thrust::greater()); + else + thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg); + }); + cudaDeviceSynchronize(); + + for(int i = 0; i < NUM_STREAMS; i++) + assert(cudaStreamDestroy(streams[i]) == cudaSuccess); x = x.view(x_sizes); y = y.view(x_sizes);