| @@ -29,31 +29,27 @@ struct stable_sort_impl_cuda { | |||||
| if (dim != -1) | if (dim != -1) | ||||
| values = torch::transpose(values, dim, -1); | values = torch::transpose(values, dim, -1); | ||||
| auto values_sizes = values.sizes(); | |||||
| auto orig_sizes = values.sizes(); | |||||
| values = values.view({ -1, values.size(-1) }).contiguous(); | values = values.view({ -1, values.size(-1) }).contiguous(); | ||||
| auto n_cols = values.size(1); | auto n_cols = values.size(1); | ||||
| auto n_rows = values.size(0); | auto n_rows = values.size(0); | ||||
| auto n = n_rows * n_cols; | |||||
| assert(values.stride(-2) == n_cols); | assert(values.stride(-2) == n_cols); | ||||
| assert(values.stride(-1) == 1); | assert(values.stride(-1) == 1); | ||||
| auto values_ptr = values.data_ptr<T>(); | auto values_ptr = values.data_ptr<T>(); | ||||
| auto indices = torch::repeat_interleave( | |||||
| torch::arange(0, n_cols, 1, torch::TensorOptions() | |||||
| auto indices = torch::arange(0, n, 1, torch::TensorOptions() | |||||
| .dtype(torch::kInt64) | .dtype(torch::kInt64) | ||||
| .device(values.device())).view({ 1, -1 }), | |||||
| n_rows, | |||||
| 0 /* dim */ | |||||
| ); | |||||
| .device(values.device())).view({ n_rows, n_cols }); | |||||
| assert(indices.stride(-2) == n_cols); | assert(indices.stride(-2) == n_cols); | ||||
| assert(indices.stride(-1) == 1); | assert(indices.stride(-1) == 1); | ||||
| auto indices_ptr = indices.data_ptr<int64_t>(); | |||||
| auto n = n_rows * n_cols; | |||||
| auto indices_ptr = indices.data_ptr<int64_t>(); | |||||
| auto ind_beg = thrust::device_pointer_cast(indices_ptr); | auto ind_beg = thrust::device_pointer_cast(indices_ptr); | ||||
| auto val_beg = thrust::device_pointer_cast(values_ptr); | auto val_beg = thrust::device_pointer_cast(values_ptr); | ||||
| @@ -74,15 +70,18 @@ struct stable_sort_impl_cuda { | |||||
| thrust::transform(thrust::device, | thrust::transform(thrust::device, | ||||
| ind_beg, ind_beg + n, n_cols_iter, | ind_beg, ind_beg + n, n_cols_iter, | ||||
| segments.begin(), thrust::modulus<int64_t>()); | |||||
| segments.begin(), thrust::divides<int64_t>()); | |||||
| thrust::stable_sort_by_key(thrust::device, segments.begin(), | thrust::stable_sort_by_key(thrust::device, segments.begin(), | ||||
| segments.end(), ind_beg); | segments.end(), ind_beg); | ||||
| thrust::transform(thrust::device, ind_beg, ind_beg + n, | |||||
| n_cols_iter, ind_beg, thrust::modulus<int64_t>()); | |||||
| cudaDeviceSynchronize(); | cudaDeviceSynchronize(); | ||||
| values = values.view(values_sizes); | |||||
| indices = indices.view(values_sizes); | |||||
| values = values.view(orig_sizes); | |||||
| indices = indices.view(orig_sizes); | |||||
| if (dim != -1) | if (dim != -1) | ||||
| values = torch::transpose(values, dim, -1).contiguous(); | values = torch::transpose(values, dim, -1).contiguous(); | ||||