@@ -1,3 +1,5 @@ | |||||
#pragma once | |||||
#include <utility> | #include <utility> | ||||
template<template<typename T> class F, typename R, typename... Ts> | template<template<typename T> class F, typename R, typename... Ts> | ||||
@@ -5,5 +5,6 @@ setup(name='torch_stablesort', | |||||
py_modules=['torch_stablesort'], | py_modules=['torch_stablesort'], | ||||
ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', | ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', | ||||
['torch_stablesort.cpp'], | ['torch_stablesort.cpp'], | ||||
extra_compile_args=['-fopenmp', '-ggdb', '-std=c++1z'])], | |||||
cmdclass={'build_ext': cpp_extension.BuildExtension}) | |||||
extra_compile_args=['-I/pstore/home/adaszews/scratch/thrust', | |||||
'-fopenmp', '-ggdb', '-std=c++1z'])], | |||||
cmdclass={'build_ext': cpp_extension.BuildExtension}) |
@@ -4,105 +4,8 @@ | |||||
#include <vector> | #include <vector> | ||||
#include <algorithm> | #include <algorithm> | ||||
#include "dispatch.h" | |||||
template<bool descending, typename T> | |||||
struct stable_sort_impl { | |||||
std::vector<torch::Tensor> operator()( | |||||
torch::Tensor input, | |||||
int dim, | |||||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out | |||||
) const { | |||||
if (input.is_sparse()) | |||||
throw std::runtime_error("Sparse tensors are not supported"); | |||||
if (input.device().type() != torch::DeviceType::CPU) | |||||
throw std::runtime_error("Only CPU tensors are supported"); | |||||
if (out != torch::nullopt) | |||||
throw std::runtime_error("out argument is not supported"); | |||||
auto in = (dim != -1) ? | |||||
torch::transpose(input, dim, -1) : | |||||
input; | |||||
auto in_sizes = in.sizes(); | |||||
// std::cout << "in_sizes: " << in_sizes << std::endl; | |||||
in = in.view({ -1, in.size(-1) }).contiguous(); | |||||
auto in_outer_stride = in.stride(-2); | |||||
auto in_inner_stride = in.stride(-1); | |||||
auto pin = static_cast<T*>(in.data_ptr()); | |||||
auto x = in.clone(); | |||||
auto x_outer_stride = x.stride(-2); | |||||
auto x_inner_stride = x.stride(-1); | |||||
auto n_cols = x.size(1); | |||||
auto n_rows = x.size(0); | |||||
auto px = static_cast<T*>(x.data_ptr()); | |||||
auto y = torch::empty({ n_rows, n_cols }, | |||||
torch::TensorOptions().dtype(torch::kInt64)); | |||||
auto y_outer_stride = y.stride(-2); | |||||
auto y_inner_stride = y.stride(-1); | |||||
auto py = static_cast<int64_t*>(y.data_ptr()); | |||||
#pragma omp parallel for | |||||
for (decltype(n_rows) i = 0; i < n_rows; i++) { | |||||
std::vector<int64_t> indices(n_cols); | |||||
for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||||
indices[k] = k; | |||||
} | |||||
std::stable_sort(std::begin(indices), std::end(indices), | |||||
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { | |||||
auto va = pin[i * in_outer_stride + a * in_inner_stride]; | |||||
auto vb = pin[i * in_outer_stride + b * in_inner_stride]; | |||||
if constexpr(descending) | |||||
return (vb < va); | |||||
else | |||||
return (va < vb); | |||||
}); | |||||
for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||||
py[i * y_outer_stride + k * y_inner_stride] = indices[k]; | |||||
px[i * x_outer_stride + k * x_inner_stride] = | |||||
pin[i * in_outer_stride + indices[k] * in_inner_stride]; | |||||
} | |||||
} | |||||
// std::cout << "Here" << std::endl; | |||||
x = x.view(in_sizes); | |||||
y = y.view(in_sizes); | |||||
x = (dim == -1) ? | |||||
x : | |||||
torch::transpose(x, dim, -1).contiguous(); | |||||
y = (dim == -1) ? | |||||
y : | |||||
torch::transpose(y, dim, -1).contiguous(); | |||||
// std::cout << "Here 2" << std::endl; | |||||
return { x, y }; | |||||
} | |||||
}; | |||||
template <typename T> | |||||
struct stable_sort_impl_desc: stable_sort_impl<true, T> {}; | |||||
template <typename T> | |||||
struct stable_sort_impl_asc: stable_sort_impl<false, T> {}; | |||||
#include "torch_stablesort_cuda.h" | |||||
#include "torch_stablesort_cpu.h" | |||||
std::vector<torch::Tensor> stable_sort( | std::vector<torch::Tensor> stable_sort( | ||||
torch::Tensor input, | torch::Tensor input, | ||||
@@ -110,12 +13,14 @@ std::vector<torch::Tensor> stable_sort( | |||||
bool descending = false, | bool descending = false, | ||||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) { | torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) { | ||||
if (descending) | |||||
return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>( | |||||
input, dim, out); | |||||
else | |||||
return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>( | |||||
input, dim, out); | |||||
switch (input.device().type()) { | |||||
case torch::DeviceType::CUDA: | |||||
return dispatch_cuda(input, dim, descending, out); | |||||
case torch::DeviceType::CPU: | |||||
return dispatch_cpu(input, dim, descending, out); | |||||
default: | |||||
throw std::runtime_error("Unsupported device type"); | |||||
} | |||||
} | } | ||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||||
@@ -0,0 +1,119 @@ | |||||
#pragma once | |||||
#include <torch/extension.h> | |||||
#include <vector> | |||||
#include <tuple> | |||||
#include "dispatch.h" | |||||
template<bool descending, typename T> | |||||
struct stable_sort_impl { | |||||
std::vector<torch::Tensor> operator()( | |||||
torch::Tensor input, | |||||
int dim, | |||||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out | |||||
) const { | |||||
if (input.is_sparse()) | |||||
throw std::runtime_error("Sparse tensors are not supported"); | |||||
if (input.device().type() != torch::DeviceType::CPU) | |||||
throw std::runtime_error("Only CPU tensors are supported"); | |||||
if (out != torch::nullopt) | |||||
throw std::runtime_error("out argument is not supported"); | |||||
auto in = (dim != -1) ? | |||||
torch::transpose(input, dim, -1) : | |||||
input; | |||||
auto in_sizes = in.sizes(); | |||||
// std::cout << "in_sizes: " << in_sizes << std::endl; | |||||
in = in.view({ -1, in.size(-1) }).contiguous(); | |||||
auto in_outer_stride = in.stride(-2); | |||||
auto in_inner_stride = in.stride(-1); | |||||
auto pin = static_cast<T*>(in.data_ptr()); | |||||
auto x = in.clone(); | |||||
auto x_outer_stride = x.stride(-2); | |||||
auto x_inner_stride = x.stride(-1); | |||||
auto n_cols = x.size(1); | |||||
auto n_rows = x.size(0); | |||||
auto px = static_cast<T*>(x.data_ptr()); | |||||
auto y = torch::empty({ n_rows, n_cols }, | |||||
torch::TensorOptions().dtype(torch::kInt64)); | |||||
auto y_outer_stride = y.stride(-2); | |||||
auto y_inner_stride = y.stride(-1); | |||||
auto py = static_cast<int64_t*>(y.data_ptr()); | |||||
#pragma omp parallel for | |||||
for (decltype(n_rows) i = 0; i < n_rows; i++) { | |||||
std::vector<int64_t> indices(n_cols); | |||||
for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||||
indices[k] = k; | |||||
} | |||||
std::stable_sort(std::begin(indices), std::end(indices), | |||||
[pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) { | |||||
auto va = pin[i * in_outer_stride + a * in_inner_stride]; | |||||
auto vb = pin[i * in_outer_stride + b * in_inner_stride]; | |||||
if constexpr(descending) | |||||
return (vb < va); | |||||
else | |||||
return (va < vb); | |||||
}); | |||||
for (decltype(n_cols) k = 0; k < n_cols; k++) { | |||||
py[i * y_outer_stride + k * y_inner_stride] = indices[k]; | |||||
px[i * x_outer_stride + k * x_inner_stride] = | |||||
pin[i * in_outer_stride + indices[k] * in_inner_stride]; | |||||
} | |||||
} | |||||
// std::cout << "Here" << std::endl; | |||||
x = x.view(in_sizes); | |||||
y = y.view(in_sizes); | |||||
x = (dim == -1) ? | |||||
x : | |||||
torch::transpose(x, dim, -1).contiguous(); | |||||
y = (dim == -1) ? | |||||
y : | |||||
torch::transpose(y, dim, -1).contiguous(); | |||||
// std::cout << "Here 2" << std::endl; | |||||
return { x, y }; | |||||
} | |||||
}; | |||||
template <typename T> | |||||
struct stable_sort_impl_desc: stable_sort_impl<true, T> {}; | |||||
template <typename T> | |||||
struct stable_sort_impl_asc: stable_sort_impl<false, T> {}; | |||||
std::vector<torch::Tensor> dispatch_cpu(torch::Tensor input, | |||||
int dim, | |||||
bool descending, | |||||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) { | |||||
if (descending) | |||||
return dispatch<stable_sort_impl_desc, std::vector<torch::Tensor>>( | |||||
input, dim, out); | |||||
else | |||||
return dispatch<stable_sort_impl_asc, std::vector<torch::Tensor>>( | |||||
input, dim, out); | |||||
} |
@@ -0,0 +1,109 @@ | |||||
#pragma once | |||||
#include <torch/extension.h> | |||||
#include <thrust/sort.h> | |||||
#include <thrust/device_ptr.h> | |||||
#include <thrust/execution_policy.h> | |||||
#include <vector> | |||||
#include <tuple> | |||||
#include "dispatch.h" | |||||
template<bool descending, typename T> | |||||
struct stable_sort_impl_cuda { | |||||
std::vector<torch::Tensor> operator()( | |||||
torch::Tensor input, | |||||
int dim, | |||||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out | |||||
) const { | |||||
if (input.is_sparse()) | |||||
throw std::runtime_error("Sparse tensors are not supported"); | |||||
if (input.device().type() != torch::DeviceType::CUDA) | |||||
throw std::runtime_error("Only CUDA tensors are supported"); | |||||
if (out != torch::nullopt) | |||||
throw std::runtime_error("out argument is not supported"); | |||||
auto x = input.clone(); | |||||
if (dim != -1) | |||||
x = torch::transpose(x, dim, -1); | |||||
auto x_sizes = x.sizes(); | |||||
x = x.view({ -1, x.size(-1) }).contiguous(); | |||||
auto x_outer_stride = x.stride(-2); | |||||
auto x_inner_stride = x.stride(-1); | |||||
auto n_cols = x.size(1); | |||||
auto n_rows = x.size(0); | |||||
auto px = x.data_ptr<T>(); | |||||
assert(x_inner_stride == 1); | |||||
auto y = torch::repeat_interleave( | |||||
torch::arange(0, n_cols, 1, torch::TensorOptions() | |||||
.dtype(torch::kInt32) | |||||
.device(x.device())), | |||||
torch::ones(n_rows, torch::TensorOptions() | |||||
.dtype(torch::kInt32) | |||||
.device(x.device())) | |||||
); | |||||
auto y_outer_stride = y.stride(-2); | |||||
auto y_inner_stride = y.stride(-1); | |||||
auto py = y.data_ptr<int32_t>(); | |||||
assert(y_inner_stride == 1); | |||||
for (decltype(n_rows) i = 0; i < n_rows; i++) { | |||||
auto ind_beg = thrust::device_pointer_cast(py + i * y_outer_stride); | |||||
auto val_beg = thrust::device_pointer_cast(px + i * x_outer_stride); | |||||
auto val_end = thrust::device_pointer_cast(px + i * x_outer_stride + | |||||
n_cols * x_inner_stride); | |||||
if constexpr(descending) | |||||
thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg, | |||||
thrust::greater<T>()); | |||||
else | |||||
thrust::stable_sort_by_key(thrust::device, val_beg, val_end, ind_beg); | |||||
} | |||||
x = x.view(x_sizes); | |||||
y = y.view(x_sizes); | |||||
x = (dim == -1) ? | |||||
x : | |||||
torch::transpose(x, dim, -1).contiguous(); | |||||
y = (dim == -1) ? | |||||
y : | |||||
torch::transpose(y, dim, -1).contiguous(); | |||||
return { x, y }; | |||||
} | |||||
}; | |||||
template <typename T> | |||||
struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda<true, T> {}; | |||||
template <typename T> | |||||
struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda<false, T> {}; | |||||
std::vector<torch::Tensor> dispatch_cuda(torch::Tensor input, | |||||
int dim, | |||||
bool descending, | |||||
torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out) { | |||||
if (descending) | |||||
return dispatch<stable_sort_impl_desc_cuda, std::vector<torch::Tensor>>( | |||||
input, dim, out); | |||||
else | |||||
return dispatch<stable_sort_impl_asc_cuda, std::vector<torch::Tensor>>( | |||||
input, dim, out); | |||||
} |