|
|
@@ -29,31 +29,27 @@ struct stable_sort_impl_cuda { |
|
|
|
if (dim != -1)
|
|
|
|
values = torch::transpose(values, dim, -1);
|
|
|
|
|
|
|
|
auto values_sizes = values.sizes();
|
|
|
|
auto orig_sizes = values.sizes();
|
|
|
|
|
|
|
|
values = values.view({ -1, values.size(-1) }).contiguous();
|
|
|
|
|
|
|
|
auto n_cols = values.size(1);
|
|
|
|
auto n_rows = values.size(0);
|
|
|
|
auto n = n_rows * n_cols;
|
|
|
|
|
|
|
|
assert(values.stride(-2) == n_cols);
|
|
|
|
assert(values.stride(-1) == 1);
|
|
|
|
|
|
|
|
auto values_ptr = values.data_ptr<T>();
|
|
|
|
|
|
|
|
auto indices = torch::repeat_interleave(
|
|
|
|
torch::arange(0, n_cols, 1, torch::TensorOptions()
|
|
|
|
auto indices = torch::arange(0, n, 1, torch::TensorOptions()
|
|
|
|
.dtype(torch::kInt64)
|
|
|
|
.device(values.device())).view({ 1, -1 }),
|
|
|
|
n_rows,
|
|
|
|
0 /* dim */
|
|
|
|
);
|
|
|
|
.device(values.device())).view({ n_rows, n_cols });
|
|
|
|
|
|
|
|
assert(indices.stride(-2) == n_cols);
|
|
|
|
assert(indices.stride(-1) == 1);
|
|
|
|
auto indices_ptr = indices.data_ptr<int64_t>();
|
|
|
|
|
|
|
|
auto n = n_rows * n_cols;
|
|
|
|
auto indices_ptr = indices.data_ptr<int64_t>();
|
|
|
|
|
|
|
|
auto ind_beg = thrust::device_pointer_cast(indices_ptr);
|
|
|
|
auto val_beg = thrust::device_pointer_cast(values_ptr);
|
|
|
@@ -74,15 +70,18 @@ struct stable_sort_impl_cuda { |
|
|
|
|
|
|
|
thrust::transform(thrust::device,
|
|
|
|
ind_beg, ind_beg + n, n_cols_iter,
|
|
|
|
segments.begin(), thrust::modulus<int64_t>());
|
|
|
|
segments.begin(), thrust::divides<int64_t>());
|
|
|
|
|
|
|
|
thrust::stable_sort_by_key(thrust::device, segments.begin(),
|
|
|
|
segments.end(), ind_beg);
|
|
|
|
|
|
|
|
thrust::transform(thrust::device, ind_beg, ind_beg + n,
|
|
|
|
n_cols_iter, ind_beg, thrust::modulus<int64_t>());
|
|
|
|
|
|
|
|
cudaDeviceSynchronize();
|
|
|
|
|
|
|
|
values = values.view(values_sizes);
|
|
|
|
indices = indices.view(values_sizes);
|
|
|
|
values = values.view(orig_sizes);
|
|
|
|
indices = indices.view(orig_sizes);
|
|
|
|
|
|
|
|
if (dim != -1)
|
|
|
|
values = torch::transpose(values, dim, -1).contiguous();
|
|
|
|