|
@@ -43,32 +43,43 @@ struct stable_sort_impl_cuda { |
|
|
|
|
|
|
|
|
auto y = torch::repeat_interleave(
|
|
|
auto y = torch::repeat_interleave(
|
|
|
torch::arange(0, n_cols, 1, torch::TensorOptions()
|
|
|
torch::arange(0, n_cols, 1, torch::TensorOptions()
|
|
|
.dtype(torch::kInt32)
|
|
|
|
|
|
.device(x.device())),
|
|
|
|
|
|
torch::ones(n_rows, torch::TensorOptions()
|
|
|
|
|
|
.dtype(torch::kInt32)
|
|
|
|
|
|
.device(x.device()))
|
|
|
|
|
|
|
|
|
.dtype(torch::kInt64)
|
|
|
|
|
|
.device(x.device())).view({ 1, -1 }),
|
|
|
|
|
|
n_rows,
|
|
|
|
|
|
0 /* dim */
|
|
|
);
|
|
|
);
|
|
|
|
|
|
|
|
|
auto y_outer_stride = y.stride(-2);
|
|
|
auto y_outer_stride = y.stride(-2);
|
|
|
auto y_inner_stride = y.stride(-1);
|
|
|
auto y_inner_stride = y.stride(-1);
|
|
|
auto py = y.data_ptr<int32_t>();
|
|
|
|
|
|
|
|
|
auto py = y.data_ptr<int64_t>();
|
|
|
|
|
|
|
|
|
assert(y_inner_stride == 1);
|
|
|
assert(y_inner_stride == 1);
|
|
|
|
|
|
|
|
|
for (decltype(n_rows) i = 0; i < n_rows; i++) {
|
|
|
|
|
|
auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
|
|
|
|
|
|
|
|
|
|
|
|
auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
|
|
|
|
|
|
auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
|
|
|
|
|
|
n_cols * x_inner_stride);
|
|
|
|
|
|
|
|
|
|
|
|
if constexpr(descending)
|
|
|
|
|
|
thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg,
|
|
|
|
|
|
thrust::greater<T>());
|
|
|
|
|
|
else
|
|
|
|
|
|
thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg);
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
#define NUM_STREAMS 16
|
|
|
|
|
|
cudaStream_t streams[NUM_STREAMS];
|
|
|
|
|
|
for(int i = 0; i < NUM_STREAMS; i++)
|
|
|
|
|
|
assert(cudaStreamCreate(&streams[i]) == cudaSuccess);
|
|
|
|
|
|
|
|
|
|
|
|
thrust::host_vector<int64_t> row_indices(n_rows);
|
|
|
|
|
|
thrust::sequence(row_indices.begin(), row_indices.end());
|
|
|
|
|
|
thrust::for_each(thrust::host, row_indices.begin(), row_indices.end(),
|
|
|
|
|
|
[&streams, py, y_outer_stride, px, x_outer_stride, x_inner_stride, n_cols](int64_t i) {
|
|
|
|
|
|
auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
|
|
|
|
|
|
|
|
|
|
|
|
auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
|
|
|
|
|
|
auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
|
|
|
|
|
|
n_cols * x_inner_stride);
|
|
|
|
|
|
|
|
|
|
|
|
if (descending)
|
|
|
|
|
|
thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg,
|
|
|
|
|
|
thrust::greater<T>());
|
|
|
|
|
|
else
|
|
|
|
|
|
thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg);
|
|
|
|
|
|
});
|
|
|
|
|
|
cudaDeviceSynchronize();
|
|
|
|
|
|
|
|
|
|
|
|
for(int i = 0; i < NUM_STREAMS; i++)
|
|
|
|
|
|
assert(cudaStreamDestroy(streams[i]) == cudaSuccess);
|
|
|
|
|
|
|
|
|
x = x.view(x_sizes);
|
|
|
x = x.view(x_sizes);
|
|
|
y = y.view(x_sizes);
|
|
|
y = y.view(x_sizes);
|
|
|