IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Parcourir la source

CUDA performance needs improvement.

master
Stanislaw Adaszewski il y a 4 ans
Parent
révision
c1353168f3
2 fichiers modifiés avec 32 ajouts et 20 suppressions
  1. +2
    -1
      src/torch_stablesort/setup.py
  2. +30
    -19
      src/torch_stablesort/torch_stablesort_cuda.cu

+ 2
- 1
src/torch_stablesort/setup.py Voir le fichier

@@ -8,6 +8,7 @@ setup(name='torch_stablesort',
extra_compile_args={
'cxx': ['-fopenmp', '-ggdb', '-std=c++1z'],
'nvcc': [ '-I/pstore/home/adaszews/scratch/thrust',
'-ccbin', '/pstore/data/data_science/app/modules/anaconda3-2020.07/bin/x86_64-conda_cos6-linux-gnu-gcc', '-std=c++14']
'-ccbin', '/pstore/data/data_science/app/modules/anaconda3-2020.07/bin/x86_64-conda_cos6-linux-gnu-gcc',
'-std=c++14', '--expt-extended-lambda', '-O99']
} ) ],
cmdclass={'build_ext': cpp_extension.BuildExtension})

+ 30
- 19
src/torch_stablesort/torch_stablesort_cuda.cu Voir le fichier

@@ -43,32 +43,43 @@ struct stable_sort_impl_cuda {
auto y = torch::repeat_interleave(
torch::arange(0, n_cols, 1, torch::TensorOptions()
.dtype(torch::kInt32)
.device(x.device())),
torch::ones(n_rows, torch::TensorOptions()
.dtype(torch::kInt32)
.device(x.device()))
.dtype(torch::kInt64)
.device(x.device())).view({ 1, -1 }),
n_rows,
0 /* dim */
);
auto y_outer_stride = y.stride(-2);
auto y_inner_stride = y.stride(-1);
auto py = y.data_ptr<int32_t>();
auto py = y.data_ptr<int64_t>();
assert(y_inner_stride == 1);
for (decltype(n_rows) i = 0; i < n_rows; i++) {
auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
n_cols * x_inner_stride);
if constexpr(descending)
thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg,
thrust::greater<T>());
else
thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg);
}
#define NUM_STREAMS 16
cudaStream_t streams[NUM_STREAMS];
for(int i = 0; i < NUM_STREAMS; i++)
assert(cudaStreamCreate(&streams[i]) == cudaSuccess);
thrust::host_vector<int64_t> row_indices(n_rows);
thrust::sequence(row_indices.begin(), row_indices.end());
thrust::for_each(thrust::host, row_indices.begin(), row_indices.end(),
[&streams, py, y_outer_stride, px, x_outer_stride, x_inner_stride, n_cols](int64_t i) {
auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
n_cols * x_inner_stride);
if (descending)
thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg,
thrust::greater<T>());
else
thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg);
});
cudaDeviceSynchronize();
for(int i = 0; i < NUM_STREAMS; i++)
assert(cudaStreamDestroy(streams[i]) == cudaSuccess);
x = x.view(x_sizes);
y = y.view(x_sizes);


Chargement…
Annuler
Enregistrer