@@ -1,3 +1,5 @@ | |||
#pragma once | |||
#include <utility> | |||
template<template<typename T> class F, typename R, typename... Ts> | |||
@@ -5,5 +5,6 @@ setup(name='torch_stablesort', | |||
py_modules=['torch_stablesort'], | |||
ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', | |||
['torch_stablesort.cpp'], | |||
extra_compile_args=['-fopenmp', '-ggdb', '-std=c++1z'])], | |||
cmdclass={'build_ext': cpp_extension.BuildExtension}) | |||
extra_compile_args=['-I/pstore/home/adaszews/scratch/thrust', | |||
'-fopenmp', '-ggdb', '-std=c++1z'])], | |||
cmdclass={'build_ext': cpp_extension.BuildExtension}) |
@@ -4,105 +4,8 @@ | |||
#include <vector> | |||
#include <algorithm> | |||
#include "dispatch.h" | |||
template<bool descending, typename T> | |||
struct stable_sort_impl { | |||
std::vector<torch::Tensor> operator()( | |||
torch::Tensor input, | |||
int dim, | |||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out | |||
) const { | |||
if (input.is_sparse()) | |||
throw std::runtime_error("Sparse tensors are not supported"); | |||
if (input.device().type() != torch::DeviceType::CPU) | |||
throw std::runtime_error("Only CPU tensors are supported"); | |||
if (out != torch::nullopt) | |||
throw std::runtime_error("out argument is not supported"); | |||
auto in = (dim != -1) ? | |||
torch::transpose(input, dim, -1) : | |||
input; | |||
auto in_sizes = in.sizes(); | |||
// std::cout << "in_sizes: " << in_sizes << std::endl; | |||
in = in.view({ -1, in.size(-1) }).contiguous(); | |||
auto in_outer_stride = in.stride(-2); | |||
auto in_inner_stride = in.stride(-1); | |||
auto pin = static_cast<T*>(in.data_ptr()); | |||
auto x = in.clone(); | |||
auto x_outer_stride = x.stride(-2); | |||
auto x_inner_stride = x.stride(-1); | |||
auto n_cols = x.size(1); | |||
auto n_rows = x.size(0); | |||
auto px = static_cast<T*>(x.data_ptr()); | |||
auto y = torch::empty({ n_rows, n_cols }, | |||
torch::TensorOptions().dtype(torch::kInt64)); | |||
auto y_outer_stride = y.stride(-2); | |||
auto y_inner_stride = y.stride(-1); | |||
auto py = static_cast<int64_t*>(y.data_ptr()); | |||
#pragma omp parallel for | |||
for (decltype(n_rows) i = 0; i < n_rows; i++) { | |||
std::vector<int64_t> indices(n_cols); | |||
for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||
indices[k] = k; | |||
} | |||
std::stable_sort(std::begin(indices), std::end(indices), | |||
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { | |||
auto va = pin[i * in_outer_stride + a * in_inner_stride]; | |||
auto vb = pin[i * in_outer_stride + b * in_inner_stride]; | |||
if constexpr(descending) | |||
return (vb < va); | |||
else | |||
return (va < vb); | |||
}); | |||
for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||
py[i * y_outer_stride + k * y_inner_stride] = indices[k]; | |||
px[i * x_outer_stride + k * x_inner_stride] = | |||
pin[i * in_outer_stride + indices[k] * in_inner_stride]; | |||
} | |||
} | |||
// std::cout << "Here" << std::endl; | |||
x = x.view(in_sizes); | |||
y = y.view(in_sizes); | |||
x = (dim == -1) ? | |||
x : | |||
torch::transpose(x, dim, -1).contiguous(); | |||
y = (dim == -1) ? | |||
y : | |||
torch::transpose(y, dim, -1).contiguous(); | |||
// std::cout << "Here 2" << std::endl; | |||
return { x, y }; | |||
} | |||
}; | |||
template <typename T> | |||
struct stable_sort_impl_desc: stable_sort_impl<true, T> {}; | |||
template <typename T> | |||
struct stable_sort_impl_asc: stable_sort_impl<false, T> {}; | |||
#include "torch_stablesort_cuda.h" | |||
#include "torch_stablesort_cpu.h" | |||
std::vector<torch::Tensor> stable_sort( | |||
torch::Tensor input, | |||
@@ -110,12 +13,14 @@ std::vector<torch::Tensor> stable_sort( | |||
bool descending = false, | |||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) { | |||
if (descending) | |||
return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>( | |||
input, dim, out); | |||
else | |||
return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>( | |||
input, dim, out); | |||
switch (input.device().type()) { | |||
case torch::DeviceType::CUDA: | |||
return dispatch_cuda(input, dim, descending, out); | |||
case torch::DeviceType::CPU: | |||
return dispatch_cpu(input, dim, descending, out); | |||
default: | |||
throw std::runtime_error("Unsupported device type"); | |||
} | |||
} | |||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |||
@@ -0,0 +1,119 @@ | |||
#pragma once | |||
#include <torch/extension.h> | |||
#include <vector> | |||
#include <tuple> | |||
#include "dispatch.h" | |||
template<bool descending, typename T> | |||
struct stable_sort_impl { | |||
std::vector<torch::Tensor> operator()( | |||
torch::Tensor input, | |||
int dim, | |||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out | |||
) const { | |||
if (input.is_sparse()) | |||
throw std::runtime_error("Sparse tensors are not supported"); | |||
if (input.device().type() != torch::DeviceType::CPU) | |||
throw std::runtime_error("Only CPU tensors are supported"); | |||
if (out != torch::nullopt) | |||
throw std::runtime_error("out argument is not supported"); | |||
auto in = (dim != -1) ? | |||
torch::transpose(input, dim, -1) : | |||
input; | |||
auto in_sizes = in.sizes(); | |||
// std::cout << "in_sizes: " << in_sizes << std::endl; | |||
in = in.view({ -1, in.size(-1) }).contiguous(); | |||
auto in_outer_stride = in.stride(-2); | |||
auto in_inner_stride = in.stride(-1); | |||
auto pin = static_cast<T*>(in.data_ptr()); | |||
auto x = in.clone(); | |||
auto x_outer_stride = x.stride(-2); | |||
auto x_inner_stride = x.stride(-1); | |||
auto n_cols = x.size(1); | |||
auto n_rows = x.size(0); | |||
auto px = static_cast<T*>(x.data_ptr()); | |||
auto y = torch::empty({ n_rows, n_cols }, | |||
torch::TensorOptions().dtype(torch::kInt64)); | |||
auto y_outer_stride = y.stride(-2); | |||
auto y_inner_stride = y.stride(-1); | |||
auto py = static_cast<int64_t*>(y.data_ptr()); | |||
#pragma omp parallel for | |||
for (decltype(n_rows) i = 0; i < n_rows; i++) { | |||
std::vector<int64_t> indices(n_cols); | |||
for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||
indices[k] = k; | |||
} | |||
std::stable_sort(std::begin(indices), std::end(indices), | |||
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { | |||
auto va = pin[i * in_outer_stride + a * in_inner_stride]; | |||
auto vb = pin[i * in_outer_stride + b * in_inner_stride]; | |||
if constexpr(descending) | |||
return (vb < va); | |||
else | |||
return (va < vb); | |||
}); | |||
for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||
py[i * y_outer_stride + k * y_inner_stride] = indices[k]; | |||
px[i * x_outer_stride + k * x_inner_stride] = | |||
pin[i * in_outer_stride + indices[k] * in_inner_stride]; | |||
} | |||
} | |||
// std::cout << "Here" << std::endl; | |||
x = x.view(in_sizes); | |||
y = y.view(in_sizes); | |||
x = (dim == -1) ? | |||
x : | |||
torch::transpose(x, dim, -1).contiguous(); | |||
y = (dim == -1) ? | |||
y : | |||
torch::transpose(y, dim, -1).contiguous(); | |||
// std::cout << "Here 2" << std::endl; | |||
return { x, y }; | |||
} | |||
}; | |||
template <typename T> | |||
struct stable_sort_impl_desc: stable_sort_impl<true, T> {}; | |||
template <typename T> | |||
struct stable_sort_impl_asc: stable_sort_impl<false, T> {}; | |||
std::vector<torch::Tensor> dispatch_cpu(torch::Tensor input, | |||
int dim, | |||
bool descending, | |||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) { | |||
if (descending) | |||
return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>( | |||
input, dim, out); | |||
else | |||
return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>( | |||
input, dim, out); | |||
} |
@@ -0,0 +1,109 @@ | |||
#pragma once | |||
#include <torch/extension.h> | |||
#include <thrust/sort.h> | |||
#include <thrust/device_ptr.h> | |||
#include <thrust/execution_policy.h> | |||
#include <vector> | |||
#include <tuple> | |||
#include "dispatch.h" | |||
template<bool descending, typename T> | |||
struct stable_sort_impl_cuda { | |||
std::vector<torch::Tensor> operator()( | |||
torch::Tensor input, | |||
int dim, | |||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out | |||
) const { | |||
if (input.is_sparse()) | |||
throw std::runtime_error("Sparse tensors are not supported"); | |||
if (input.device().type() != torch::DeviceType::CUDA) | |||
throw std::runtime_error("Only CUDA tensors are supported"); | |||
if (out != torch::nullopt) | |||
throw std::runtime_error("out argument is not supported"); | |||
auto x = input.clone(); | |||
if (dim != -1) | |||
x = torch::transpose(x, dim, -1); | |||
auto x_sizes = x.sizes(); | |||
x = x.view({ -1, x.size(-1) }).contiguous(); | |||
auto x_outer_stride = x.stride(-2); | |||
auto x_inner_stride = x.stride(-1); | |||
auto n_cols = x.size(1); | |||
auto n_rows = x.size(0); | |||
auto px = x.data_ptr<T>(); | |||
assert(x_inner_stride == 1); | |||
auto y = torch::repeat_interleave( | |||
torch::arange(0, n_cols, 1, torch::TensorOptions() | |||
.dtype(torch::kInt32) | |||
.device(x.device())), | |||
torch::ones(n_rows, torch::TensorOptions() | |||
.dtype(torch::kInt32) | |||
.device(x.device())) | |||
); | |||
auto y_outer_stride = y.stride(-2); | |||
auto y_inner_stride = y.stride(-1); | |||
auto py = y.data_ptr<int32_t>(); | |||
assert(y_inner_stride == 1); | |||
for (decltype(n_rows) i = 0; i < n_rows; i++) { | |||
auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride); | |||
auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride); | |||
auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride + | |||
n_cols * x_inner_stride); | |||
if constexpr(descending) | |||
thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg, | |||
thrust::greater<T>()); | |||
else | |||
thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg); | |||
} | |||
x = x.view(x_sizes); | |||
y = y.view(x_sizes); | |||
x = (dim == -1) ? | |||
x : | |||
torch::transpose(x, dim, -1).contiguous(); | |||
y = (dim == -1) ? | |||
y : | |||
torch::transpose(y, dim, -1).contiguous(); | |||
return { x, y }; | |||
} | |||
}; | |||
template <typename T> | |||
struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {}; | |||
template <typename T> | |||
struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {}; | |||
std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input, | |||
int dim, | |||
bool descending, | |||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) { | |||
if (descending) | |||
return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>( | |||
input, dim, out); | |||
else | |||
return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>( | |||
input, dim, out); | |||
} |