| @@ -24,75 +24,73 @@ struct stable_sort_impl_cuda { | |||||
| if (out != torch::nullopt) | if (out != torch::nullopt) | ||||
| throw std::runtime_error("out argument is not supported"); | throw std::runtime_error("out argument is not supported"); | ||||
| auto x = input.clone(); | |||||
| auto values = input.clone(); | |||||
| if (dim != -1) | if (dim != -1) | ||||
| x = torch::transpose(x, dim, -1); | |||||
| values = torch::transpose(values, dim, -1); | |||||
| auto x_sizes = x.sizes(); | |||||
| auto values_sizes = values.sizes(); | |||||
| x = x.view({ -1, x.size(-1) }).contiguous(); | |||||
| values = values.view({ -1, values.size(-1) }).contiguous(); | |||||
| auto x_outer_stride = x.stride(-2); | |||||
| auto x_inner_stride = x.stride(-1); | |||||
| auto n_cols = x.size(1); | |||||
| auto n_rows = x.size(0); | |||||
| auto px = x.data_ptr<T>(); | |||||
| auto n_cols = values.size(1); | |||||
| auto n_rows = values.size(0); | |||||
| assert(x_inner_stride == 1); | |||||
| assert(values.stride(-2) == n_cols); | |||||
| assert(values.stride(-1) == 1); | |||||
| auto y = torch::repeat_interleave( | |||||
| auto values_ptr = values.data_ptr<T>(); | |||||
| auto indices = torch::repeat_interleave( | |||||
| torch::arange(0, n_cols, 1, torch::TensorOptions() | torch::arange(0, n_cols, 1, torch::TensorOptions() | ||||
| .dtype(torch::kInt64) | .dtype(torch::kInt64) | ||||
| .device(x.device())).view({ 1, -1 }), | |||||
| .device(values.device())).view({ 1, -1 }), | |||||
| n_rows, | n_rows, | ||||
| 0 /* dim */ | 0 /* dim */ | ||||
| ); | ); | ||||
| auto y_outer_stride = y.stride(-2); | |||||
| auto y_inner_stride = y.stride(-1); | |||||
| auto py = y.data_ptr<int64_t>(); | |||||
| assert(y_inner_stride == 1); | |||||
| #define NUM_STREAMS 16 | |||||
| cudaStream_t streams[NUM_STREAMS]; | |||||
| for(int i = 0; i < NUM_STREAMS; i++) | |||||
| assert(cudaStreamCreate(&streams[i]) == cudaSuccess); | |||||
| thrust::host_vector<int64_t> row_indices(n_rows); | |||||
| thrust::sequence(row_indices.begin(), row_indices.end()); | |||||
| thrust::for_each(thrust::host, row_indices.begin(), row_indices.end(), | |||||
| [&streams, py, y_outer_stride, px, x_outer_stride, x_inner_stride, n_cols](int64_t i) { | |||||
| auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride); | |||||
| auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride); | |||||
| auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride + | |||||
| n_cols * x_inner_stride); | |||||
| if (descending) | |||||
| thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg, | |||||
| thrust::greater<T>()); | |||||
| else | |||||
| thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg); | |||||
| }); | |||||
| cudaDeviceSynchronize(); | |||||
| assert(indices.stride(-2) == n_cols); | |||||
| assert(indices.stride(-1) == 1); | |||||
| auto indices_ptr = indices.data_ptr<int64_t>(); | |||||
| auto n = n_rows * n_cols; | |||||
| auto ind_beg = thrust::device_pointer_cast(indices_ptr); | |||||
| auto val_beg = thrust::device_pointer_cast(values_ptr); | |||||
| if (descending) | |||||
| thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg, thrust::greater<T>()); | |||||
| else | |||||
| thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg); | |||||
| thrust::device_vector<int64_t> segments(n); | |||||
| thrust::constant_iterator<int64_t> n_cols_iter(n_cols); | |||||
| thrust::transform(thrust::device, | |||||
| ind_beg, ind_beg + n, n_cols_iter, | |||||
| segments.begin(), thrust::divides<int64_t>()); | |||||
| for(int i = 0; i < NUM_STREAMS; i++) | |||||
| assert(cudaStreamDestroy(streams[i]) == cudaSuccess); | |||||
| thrust::stable_sort_by_key(thrust::device, segments.begin(), | |||||
| segments.end(), val_beg); | |||||
| x = x.view(x_sizes); | |||||
| y = y.view(x_sizes); | |||||
| thrust::transform(thrust::device, | |||||
| ind_beg, ind_beg + n, n_cols_iter, | |||||
| segments.begin(), thrust::modulus<int64_t>()); | |||||
| x = (dim == -1) ? | |||||
| x : | |||||
| torch::transpose(x, dim, -1).contiguous(); | |||||
| thrust::stable_sort_by_key(thrust::device, segments.begin(), | |||||
| segments.end(), ind_beg); | |||||
| y = (dim == -1) ? | |||||
| y : | |||||
| torch::transpose(y, dim, -1).contiguous(); | |||||
| cudaDeviceSynchronize(); | |||||
| values = values.view(values_sizes); | |||||
| indices = indices.view(values_sizes); | |||||
| if (dim != -1) | |||||
| values = torch::transpose(values, dim, -1).contiguous(); | |||||
| if (dim != -1) | |||||
| indices = torch::transpose(indices, dim, -1).contiguous(); | |||||
| return { x, y }; | |||||
| return { values, indices }; | |||||
| } | } | ||||
| }; | }; | ||||