diff --git a/.gitignore b/.gitignore
index 9ec7cab..63adf45 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,7 @@ __pycache__
 /docs/icosagon/*.png
 /experiments/decagon_run/profiler_results
 /experiments/decagon_run_effcat/profiler_results
+/src/torch_stablesort/dist
+/src/torch_stablesort/build
+/src/torch_stablesort/torch_stablesort.egg-info
+a.out
diff --git a/src/torch_stablesort/dispatch.h b/src/torch_stablesort/dispatch.h
new file mode 100644
index 0000000..1dc9b49
--- /dev/null
+++ b/src/torch_stablesort/dispatch.h
@@ -0,0 +1,41 @@
+#include <torch/extension.h>
+
+template<template<typename T> class F, typename R, typename... Ts>
+R dispatch(torch::Tensor input, Ts&& ... args) {
+  switch(input.type().scalarType()) {
+    case torch::ScalarType::Double:
+      return F<double>()(input, std::forward<Ts>(args)...);
+    case torch::ScalarType::Float:
+      return F<float>()(input, std::forward<Ts>(args)...);
+    case torch::ScalarType::Half:
+      throw std::runtime_error("Half-precision float not supported");
+    case torch::ScalarType::ComplexHalf:
+      throw std::runtime_error("Half-precision complex float not supported");
+    case torch::ScalarType::ComplexFloat:
+      throw std::runtime_error("Complex float not supported");
+    case torch::ScalarType::ComplexDouble:
+      throw std::runtime_error("Complex double not supported");
+    case torch::ScalarType::Long:
+      return F<int64_t>()(input, std::forward<Ts>(args)...);
+    case torch::ScalarType::Int:
+      return F<int32_t>()(input, std::forward<Ts>(args)...);
+    case torch::ScalarType::Short:
+      return F<int16_t>()(input, std::forward<Ts>(args)...);
+    case torch::ScalarType::Char:
+      return F<int8_t>()(input, std::forward<Ts>(args)...);
+    case torch::ScalarType::Byte:
+      return F<uint8_t>()(input, std::forward<Ts>(args)...);
+    case torch::ScalarType::Bool:
+      return F<bool>()(input, std::forward<Ts>(args)...);
+    case torch::ScalarType::QInt32:
+      throw std::runtime_error("QInt32 not supported");
+    //case torch::ScalarType::QInt16:
+    //  throw std::runtime_error("QInt16 not supported");
+    case torch::ScalarType::QInt8:
+      throw std::runtime_error("QInt8 not supported");
+    case torch::ScalarType::BFloat16:
+      throw std::runtime_error("BFloat16 not supported");
+    default:
+      throw std::runtime_error("Unknown scalar type");
+  }
+}
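Note (illustration, not part of the patch): dispatch.h picks F<scalar type> at run time from the tensor's ScalarType, so a caller supplies a functor template F and an explicit result type R. A minimal sketch with a hypothetical sum_impl functor (the name and function are not from this repository):

    #include <torch/extension.h>
    #include "dispatch.h"

    // Hypothetical functor template: sums the elements of a contiguous CPU
    // tensor whose scalar type is T. Error handling is omitted.
    template<typename T>
    struct sum_impl {
      double operator()(torch::Tensor input) const {
        auto p = static_cast<T*>(input.data_ptr());
        double acc = 0;
        for (int64_t i = 0; i < input.numel(); i++)
          acc += static_cast<double>(p[i]);
        return acc;
      }
    };

    // dispatch<F, R>(input, args...) selects F<double>, F<float>, F<int64_t>, ...
    // based on input's dtype and forwards the remaining arguments.
    double sum_any(torch::Tensor input) {
      return dispatch<sum_impl, double>(input);
    }

The stable_sort() wrapper added to torch_stablesort.cpp further down uses the same pattern with stable_sort_impl and std::vector<torch::Tensor> as the result type.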
diff --git a/src/torch_stablesort/dispatch_test.cpp b/src/torch_stablesort/dispatch_test.cpp
index 5756088..76e043f 100644
--- a/src/torch_stablesort/dispatch_test.cpp
+++ b/src/torch_stablesort/dispatch_test.cpp
@@ -12,14 +12,20 @@ void dispatch(int x, Ts&& ...args) {
 
 template<typename T>
 struct bla {
   void operator()(int&& a, char&& b, double&& c) const {
-    std::cout << typeid(T).name() << " " << a << " " << b << " " << c << std::endl;
+    std::cout << sizeof(T) << " " << typeid(T).name() << " " << a << " " << b << " " << c << std::endl;
   }
 };
 
+template<typename T>
+struct bla128 {
+  void operator()(int&& a, char&& b, __float128&& c) const {
+    std::cout << sizeof(T) << " " << typeid(T).name() << " " << a << " " << b << " " << (double) c << std::endl;
+  }
+};
 
 main() {
   std::cout << "main()" << std::endl;
   //bla()(1, 'a', 5.5);
-  dispatch<bla>(5, 1, 'a', 5.5);
+  dispatch<bla128>(5, 1, 'a', (__float128) 5.5);
   dispatch<bla>(-5, 1, 'a', 5.5);
 }
diff --git a/src/torch_stablesort/openmp_test.cpp b/src/torch_stablesort/openmp_test.cpp
new file mode 100644
index 0000000..10b364e
--- /dev/null
+++ b/src/torch_stablesort/openmp_test.cpp
@@ -0,0 +1,8 @@
+#include <iostream>
+
+main() {
+  #pragma omp parallel for
+  for (int i = 0; i < 10; i++) {
+    std::cout << i << std::endl;
+  }
+}
diff --git a/src/torch_stablesort/setup.py b/src/torch_stablesort/setup.py
index 3071c80..163d8ff 100644
--- a/src/torch_stablesort/setup.py
+++ b/src/torch_stablesort/setup.py
@@ -3,5 +3,6 @@ from torch.utils import cpp_extension
 
 setup(name='torch_stablesort',
       ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp',
-        ['torch_stablesort.cpp'])],
+        ['torch_stablesort.cpp'],
+        extra_compile_args=['-fopenmp', '-ggdb'])],
       cmdclass={'build_ext': cpp_extension.BuildExtension})
diff --git a/src/torch_stablesort/torch_stablesort.cpp b/src/torch_stablesort/torch_stablesort.cpp
index d9cd856..94b3e13 100644
--- a/src/torch_stablesort/torch_stablesort.cpp
+++ b/src/torch_stablesort/torch_stablesort.cpp
@@ -4,137 +4,136 @@
 #include
 #include
+#include "dispatch.h"
 
-template<typename... Ts>
-void dispatch(torch::Tensor input, Ts&& ... args) {
-  switch(input.type().scalarType()) {
-    case torch::ScalarType::Double:
-      return fun<double>(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::Float:
-      return fun<float>(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::Half:
-      throw std::runtime_error("Half-precision float not supported");
-    case torch::ScalarType::ComplexHalf:
-      throw std::runtime_error("Half-precision complex float not supported");
-    case torch::ScalarType::ComplexFloat:
-      return fun(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::ComplexDouble:
-      return fun(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::Long:
-      return fun<int64_t>(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::Int:
-      return fun<int32_t>(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::Short:
-      return fun<int16_t>(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::Char:
-      return fun<int8_t>(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::Byte:
-      return fun<uint8_t>(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::Bool:
-      return fun<bool>(input, std::forward<Ts>(args)...);
-    case torch::ScalarType::QInt32:
-      throw std::runtime_error("QInt32 not supported");
-    case torch::ScalarType::QInt16:
-      throw std::runtime_error("QInt16 not supported");
-    case torch::ScalarType::QInt8:
-      throw std::runtime_error("QInt8 not supported");
-    case torch::ScalarType::BFloat16:
-      throw std::runtime_error("BFloat16 not supported");
-    default:
-      throw std::runtime_error("Unknown scalar type");
-  }
-}
+template<typename T>
+struct stable_sort_impl {
+  std::vector<torch::Tensor> operator()(
+    torch::Tensor input,
+    int dim,
+    bool descending,
+    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
+  ) const {
+    if (input.is_sparse())
+      throw std::runtime_error("Sparse tensors are not supported");
 
-std::vector<torch::Tensor> stable_sort_forward(
-    torch::Tensor input,
-    int dim,
-    bool descending,
-    torch::optional<torch::Tensor> out = nullptr) {
+    if (input.device().type() != torch::DeviceType::CPU)
+      throw std::runtime_error("Only CPU tensors are supported");
+    if (out != torch::nullopt)
+      throw std::runtime_error("out argument is not supported");
+    auto in = (dim != -1) ?
+      torch::transpose(input, dim, -1) :
+      input;
 
-  auto X = torch::cat({old_h, input}, /*dim=*/1);
+    auto in_sizes = in.sizes();
 
-  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
-  auto gates = gate_weights.chunk(3, /*dim=*/1);
+    // std::cout << "in_sizes: " << in_sizes << std::endl;
 
-  auto input_gate = torch::sigmoid(gates[0]);
-  auto output_gate = torch::sigmoid(gates[1]);
-  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
+    in = in.view({ -1, in.size(-1) }).contiguous();
 
-  auto new_cell = old_cell + candidate_cell * input_gate;
-  auto new_h = torch::tanh(new_cell) * output_gate;
+    auto in_outer_stride = in.stride(-2);
+    auto in_inner_stride = in.stride(-1);
 
-  return {new_h,
-          new_cell,
-          input_gate,
-          output_gate,
-          candidate_cell,
-          X,
-          gate_weights};
-}
+    auto pin = static_cast<T*>(in.data_ptr());
 
-// tanh'(z) = 1 - tanh^2(z)
-torch::Tensor d_tanh(torch::Tensor z) {
-  return 1 - z.tanh().pow(2);
-}
+    auto x = in.clone();
 
-// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
-torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
-  auto e = z.exp();
-  auto mask = (alpha * (e - 1)) < 0;
-  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
-}
+    auto x_outer_stride = x.stride(-2);
+    auto x_inner_stride = x.stride(-1);
 
-std::vector<torch::Tensor> stable_sort_backward(
-    torch::Tensor grad_h,
-    torch::Tensor grad_cell,
-    torch::Tensor new_cell,
-    torch::Tensor input_gate,
-    torch::Tensor output_gate,
-    torch::Tensor candidate_cell,
-    torch::Tensor X,
-    torch::Tensor gate_weights,
-    torch::Tensor weights) {
-  auto d_output_gate = torch::tanh(new_cell) * grad_h;
-  auto d_tanh_new_cell = output_gate * grad_h;
-  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
-
-  auto d_old_cell = d_new_cell;
-  auto d_candidate_cell = input_gate * d_new_cell;
-  auto d_input_gate = candidate_cell * d_new_cell;
-
-  auto gates = gate_weights.chunk(3, /*dim=*/1);
-  d_input_gate *= d_sigmoid(gates[0]);
-  d_output_gate *= d_sigmoid(gates[1]);
-  d_candidate_cell *= d_elu(gates[2]);
-
-  auto d_gates =
-      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
-
-  auto d_weights = d_gates.t().mm(X);
-  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
-
-  auto d_X = d_gates.mm(weights);
-  const auto state_size = grad_h.size(1);
-  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
-  auto d_input = d_X.slice(/*dim=*/1, state_size);
-
-  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
-}
+    auto n_cols = x.size(1);
+    auto n_rows = x.size(0);
+    auto px = static_cast<T*>(x.data_ptr());
 
-std::vector<torch::Tensor> stable_argsort_forward() {
+    auto y = torch::empty({ n_rows, n_cols },
+      torch::TensorOptions().dtype(torch::kInt64));
 
-}
+    auto y_outer_stride = y.stride(-2);
+    auto y_inner_stride = y.stride(-1);
+
+    auto py = static_cast<int64_t*>(y.data_ptr());
+
+    if (descending) {
+      #pragma omp parallel for
+      for (decltype(n_rows) i = 0; i < n_rows; i++) {
+
+        std::vector<int64_t> indices(n_cols);
+        for (decltype(n_cols) k = 0; k < n_cols; k++) {
+          indices[k] = k;
+        }
+
+        std::stable_sort(std::begin(indices), std::end(indices),
+          [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
+            auto va = pin[i * in_outer_stride + a * in_inner_stride];
+            auto vb = pin[i * in_outer_stride + b * in_inner_stride];
+            return (vb < va);
+          });
+
+        for (decltype(n_cols) k = 0; k < n_cols; k++) {
+          py[i * y_outer_stride + k * y_inner_stride] = indices[k];
+          px[i * x_outer_stride + k * x_inner_stride] =
+            pin[i * in_outer_stride + indices[k] * in_inner_stride];
+        }
+      }
 
-std::vector<torch::Tensor> stable_argsort_backward() {
+    } else {
+      #pragma omp parallel for
+      for (decltype(n_rows) i = 0; i < n_rows; i++) {
+
+        std::vector<int64_t> indices(n_cols);
+        for (decltype(n_cols) k = 0; k < n_cols; k++) {
+          indices[k] = k;
+        }
+
+        std::stable_sort(std::begin(indices), std::end(indices),
+          [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
+            auto va = pin[i * in_outer_stride + a * in_inner_stride];
+            auto vb = pin[i * in_outer_stride + b * in_inner_stride];
+            return (va < vb);
+          });
+
+        for (decltype(n_cols) k = 0; k < n_cols; k++) {
+          py[i * y_outer_stride + k * y_inner_stride] = indices[k];
+          px[i * x_outer_stride + k * x_inner_stride] =
+            pin[i * in_outer_stride + indices[k] * in_inner_stride];
+        }
+      }
+    }
+
+    // std::cout << "Here" << std::endl;
+
+    x = x.view(in_sizes);
+    y = y.view(in_sizes);
+
+    x = (dim == -1) ?
+      x :
+      torch::transpose(x, dim, -1).contiguous();
+
+    y = (dim == -1) ?
+      y :
+      torch::transpose(y, dim, -1).contiguous();
+
+    // std::cout << "Here 2" << std::endl;
+
+    return { x, y };
+  }
+};
+
+std::vector<torch::Tensor> stable_sort(
+  torch::Tensor input,
+  int dim = -1,
+  bool descending = false,
+  torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) {
+  return dispatch<stable_sort_impl, std::vector<torch::Tensor>>(
+    input, dim, descending, out);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("stable_sort_forward", &stable_sort_forward, "Stable sort forward");
-  m.def("stable_sort_backward", &stable_sort_backward, "Stable sort backward");
-  m.def("stable_argsort_forward", &stable_argsort_forward, "Stable argsort forward");
-  m.def("stable_argsort_backward", &stable_argsort_backward, "Stable argsort backward");
+  m.def("stable_sort", &stable_sort, "Stable sort",
+    py::arg("input"), py::arg("dim") = -1, py::arg("descending") = false,
+    py::arg("out") = nullptr);
 }
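Usage note (illustration, not part of the patch): a quick way to exercise the new kernel is to sort a small tensor with repeated keys and check that tied elements keep their original relative order, which is the point of a stable sort. A minimal C++ sketch, assuming the stable_sort() defined above is visible (e.g. compiled into the same translation unit):

    #include <torch/extension.h>
    #include <iostream>

    // Assumes stable_sort() from the patched torch_stablesort.cpp is declared here.
    int main() {
      auto t = torch::tensor({3.0, 1.0, 3.0, 1.0});
      auto res = stable_sort(t);          // defaults: dim = -1, descending = false
      std::cout << res[0] << std::endl;   // sorted values: 1, 1, 3, 3
      std::cout << res[1] << std::endl;   // indices: 1, 3, 0, 2 (ties keep input order)
    }

Through the pybind11 binding, the same call is exposed to Python as torch_stablesort_cpp.stable_sort(input, dim=-1, descending=False, out=None), returning the sorted values and the int64 index tensor.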