
torch_stablesort_cuda.cu getting there...

master
Stanislaw Adaszewski, 3 years ago
commit 897a6f0722
1 changed file with 49 additions and 51 deletions:

src/torch_stablesort/torch_stablesort_cuda.cu (+49, -51)
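The hunk below replaces the previous per-row strategy, which dispatched one thrust::stable_sort_by_key per row across a pool of 16 CUDA streams, with a batched one: a single global stable sort over the flattened tensor, followed by stable regrouping sorts keyed on row segments. The working tensors x and y are renamed to values and indices along the way.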

@@ -24,75 +24,73 @@ struct stable_sort_impl_cuda {
     if (out != torch::nullopt)
       throw std::runtime_error("out argument is not supported");
 
-    auto x = input.clone();
+    auto values = input.clone();
 
     if (dim != -1)
-      x = torch::transpose(x, dim, -1);
+      values = torch::transpose(values, dim, -1);
 
-    auto x_sizes = x.sizes();
+    auto values_sizes = values.sizes();
 
-    x = x.view({ -1, x.size(-1) }).contiguous();
+    values = values.view({ -1, values.size(-1) }).contiguous();
 
-    auto x_outer_stride = x.stride(-2);
-    auto x_inner_stride = x.stride(-1);
-    auto n_cols = x.size(1);
-    auto n_rows = x.size(0);
-    auto px = x.data_ptr<T>();
+    auto n_cols = values.size(1);
+    auto n_rows = values.size(0);
 
-    assert(x_inner_stride == 1);
+    assert(values.stride(-2) == n_cols);
+    assert(values.stride(-1) == 1);
 
-    auto y = torch::repeat_interleave(
+    auto values_ptr = values.data_ptr<T>();
+
+    auto indices = torch::repeat_interleave(
       torch::arange(0, n_cols, 1, torch::TensorOptions()
         .dtype(torch::kInt64)
-        .device(x.device())).view({ 1, -1 }),
+        .device(values.device())).view({ 1, -1 }),
       n_rows,
       0 /* dim */
     );
 
-    auto y_outer_stride = y.stride(-2);
-    auto y_inner_stride = y.stride(-1);
-    auto py = y.data_ptr<int64_t>();
-
-    assert(y_inner_stride == 1);
-
-    #define NUM_STREAMS 16
-    cudaStream_t streams[NUM_STREAMS];
-    for(int i = 0; i < NUM_STREAMS; i++)
-      assert(cudaStreamCreate(&streams[i]) == cudaSuccess);
-
-    thrust::host_vector<int64_t> row_indices(n_rows);
-    thrust::sequence(row_indices.begin(), row_indices.end());
-
-    thrust::for_each(thrust::host, row_indices.begin(), row_indices.end(),
-      [&streams, py, y_outer_stride, px, x_outer_stride, x_inner_stride, n_cols](int64_t i) {
-        auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride);
-        auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride);
-        auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride +
-          n_cols * x_inner_stride);
-        if (descending)
-          thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg,
-            thrust::greater<T>());
-        else
-          thrust::stable_sort_by_key(thrust::cuda::par.on(streams[i % NUM_STREAMS]), val_beg, val_end, ind_beg);
-      });
-
-    cudaDeviceSynchronize();
+    assert(indices.stride(-2) == n_cols);
+    assert(indices.stride(-1) == 1);
+
+    auto indices_ptr = indices.data_ptr<int64_t>();
+
+    auto n = n_rows * n_cols;
+    auto ind_beg = thrust::device_pointer_cast(indices_ptr);
+    auto val_beg = thrust::device_pointer_cast(values_ptr);
+
+    if (descending)
+      thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg, thrust::greater<T>());
+    else
+      thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg);
+
+    thrust::device_vector<int64_t> segments(n);
+    thrust::constant_iterator<int64_t> n_cols_iter(n_cols);
+    thrust::transform(thrust::device,
+      ind_beg, ind_beg + n, n_cols_iter,
+      segments.begin(), thrust::divides<int64_t>());
 
-    for(int i = 0; i < NUM_STREAMS; i++)
-      assert(cudaStreamDestroy(streams[i]) == cudaSuccess);
+    thrust::stable_sort_by_key(thrust::device, segments.begin(),
+      segments.end(), val_beg);
 
-    x = x.view(x_sizes);
-    y = y.view(x_sizes);
+    thrust::transform(thrust::device,
+      ind_beg, ind_beg + n, n_cols_iter,
+      segments.begin(), thrust::modulus<int64_t>());
 
-    x = (dim == -1) ?
-      x :
-      torch::transpose(x, dim, -1).contiguous();
+    thrust::stable_sort_by_key(thrust::device, segments.begin(),
+      segments.end(), ind_beg);
 
-    y = (dim == -1) ?
-      y :
-      torch::transpose(y, dim, -1).contiguous();
+    cudaDeviceSynchronize();
+
+    values = values.view(values_sizes);
+    indices = indices.view(values_sizes);
+
+    if (dim != -1)
+      values = torch::transpose(values, dim, -1).contiguous();
+
+    if (dim != -1)
+      indices = torch::transpose(indices, dim, -1).contiguous();
 
-    return { x, y };
+    return { values, indices };
   }
 };
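For reference, here is a minimal, self-contained sketch of the batched argsort idea the new code is built around; this is a sketch under assumptions, not the project's code. It assumes the index tensor holds global flat indices 0..n-1 (so that index / n_cols recovers the row and index % n_cols the column), and it carries values and indices through the regrouping sort together via a zip iterator, whereas the hunk above builds its index tensor per row and regroups values and indices in two separate keyed sorts.

#include <cstdio>
#include <cstdint>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>

int main() {
  // Two rows of four columns, flattened row-major.
  const int n_rows = 2, n_cols = 4, n = n_rows * n_cols;
  const float host_values[n] = { 3, 1, 2, 1,   2, 0, 3, 0 };
  thrust::device_vector<float> values(host_values, host_values + n);

  // Global flat indices 0..n-1: index / n_cols is the row,
  // index % n_cols the column.
  thrust::device_vector<int64_t> indices(n);
  thrust::sequence(indices.begin(), indices.end());

  // Pass 1: one stable sort over the whole flattened tensor;
  // the indices travel with their values.
  thrust::stable_sort_by_key(thrust::device, values.begin(), values.end(),
                             indices.begin());

  // Pass 2: regroup by row (segment = index / n_cols). Because the sort
  // is stable, each row's values stay in ascending order after grouping.
  // Values and indices are carried together through a zip iterator.
  thrust::device_vector<int64_t> segments(n);
  thrust::constant_iterator<int64_t> n_cols_iter(n_cols);
  thrust::transform(thrust::device, indices.begin(), indices.end(),
                    n_cols_iter, segments.begin(),
                    thrust::divides<int64_t>());
  thrust::stable_sort_by_key(
      thrust::device, segments.begin(), segments.end(),
      thrust::make_zip_iterator(thrust::make_tuple(values.begin(),
                                                   indices.begin())));

  // Finally, reduce global indices to per-row column indices (the argsort).
  thrust::transform(thrust::device, indices.begin(), indices.end(),
                    n_cols_iter, indices.begin(),
                    thrust::modulus<int64_t>());

  for (int i = 0; i < n; ++i)
    std::printf("%g:%lld%c", static_cast<float>(values[i]),
                static_cast<long long>(indices[i]),
                (i + 1) % n_cols == 0 ? '\n' : ' ');
  // Expected: 1:1 1:3 2:2 3:0
  //           0:1 0:3 2:0 3:2
  return 0;
}

The property both passes rely on is Thrust's stability guarantee: the second sort only groups elements by row, and within each group the value order established by the first global sort is preserved, so a per-row sort falls out of two device-wide sorts without launching one kernel per row.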

