diff --git a/src/torch_stablesort/dispatch_test.cpp b/src/torch_stablesort/dispatch_test.cpp
new file mode 100644
index 0000000..8079998
--- /dev/null
+++ b/src/torch_stablesort/dispatch_test.cpp
@@ -0,0 +1,38 @@
+#include <iostream>
+#include <stdexcept>
+#include <typeinfo>
+#include <utility>
+
+// Dummy tag type used to exercise dispatch on a non-fundamental type.
+class bla {};
+
+template<typename... Ts>
+class Dispatcher {
+public:
+  void dispatch(int x, Ts&& ...args) {
+    if (x > 0)
+      call<int>(std::forward<Ts>(args)...);
+    else
+      call<bla>(std::forward<Ts>(args)...);
+  }
+
+protected:
+  template<typename T>
+  void call(Ts&& ...args) {
+    throw std::runtime_error("Not implemented");
+  }
+};
+
+// Explicit specialization for Dispatcher<int, char, double>: prints the
+// dispatched type tag and the forwarded arguments instead of throwing.
+template<> template<typename T>
+void Dispatcher<int, char, double>::call(int&& a, char&& b, double&& c) {
+  std::cout << typeid(T).name() << " " << a << " " << b << " " << c << std::endl;
+}
+
+int main() {
+  std::cout << "main()" << std::endl;
+  Dispatcher<int, char, double> d;
+  d.dispatch(5, 1, 'a', 5.5);
+  d.dispatch(-5, 1, 'a', 5.5);
+}
diff --git a/src/torch_stablesort/setup.py b/src/torch_stablesort/setup.py
new file mode 100644
index 0000000..3071c80
--- /dev/null
+++ b/src/torch_stablesort/setup.py
@@ -0,0 +1,7 @@
+from setuptools import setup, Extension
+from torch.utils import cpp_extension
+
+setup(name='torch_stablesort',
+      ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp',
+          ['torch_stablesort.cpp'])],
+      cmdclass={'build_ext': cpp_extension.BuildExtension})
diff --git a/src/torch_stablesort/torch_stablesort.cpp b/src/torch_stablesort/torch_stablesort.cpp
new file mode 100644
index 0000000..d9cd856
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort.cpp
@@ -0,0 +1,156 @@
+#include <torch/extension.h>
+
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+
+// Worker the dispatcher routes to; the sort kernels below are expected to
+// provide a definition before dispatch() is instantiated.
+template<typename T, typename... Ts>
+void fun(torch::Tensor input, Ts&& ...args);
+
+// Calls fun<T>(input, args...) with T chosen to match the runtime scalar
+// type of `input`.
+template<typename... Ts>
+void dispatch(torch::Tensor input, Ts&& ...args) {
+  switch (input.scalar_type()) {
+  case torch::ScalarType::Double:
+    return fun<double>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::Float:
+    return fun<float>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::Half:
+    throw std::runtime_error("Half-precision float not supported");
+  case torch::ScalarType::ComplexHalf:
+    throw std::runtime_error("Half-precision complex float not supported");
+  case torch::ScalarType::ComplexFloat:
+    return fun<c10::complex<float>>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::ComplexDouble:
+    return fun<c10::complex<double>>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::Long:
+    return fun<int64_t>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::Int:
+    return fun<int32_t>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::Short:
+    return fun<int16_t>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::Char:
+    return fun<int8_t>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::Byte:
+    return fun<uint8_t>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::Bool:
+    return fun<bool>(input, std::forward<Ts>(args)...);
+  case torch::ScalarType::QInt32:
+    throw std::runtime_error("QInt32 not supported");
+  case torch::ScalarType::QUInt8:
+    throw std::runtime_error("QUInt8 not supported");
+  case torch::ScalarType::QInt8:
+    throw std::runtime_error("QInt8 not supported");
+  case torch::ScalarType::BFloat16:
+    throw std::runtime_error("BFloat16 not supported");
+  default:
+    throw std::runtime_error("Unknown scalar type");
+  }
+}
+
+
+std::vector<torch::Tensor> stable_sort_forward(
+    torch::Tensor input,
+    int dim,
+    bool descending,
+    torch::optional<torch::Tensor> out = torch::nullopt) {
+
+  // NOTE: the body below is placeholder scaffolding copied from the PyTorch
+  // LLTM extension tutorial; old_h, old_cell, weights and bias are not yet
+  // defined, and the actual stable sort still needs to be written.
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
+  auto gates = gate_weights.chunk(3, /*dim=*/1);
+
+  auto input_gate = torch::sigmoid(gates[0]);
+  auto output_gate = torch::sigmoid(gates[1]);
+  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
+
+  auto new_cell = old_cell + candidate_cell * input_gate;
+  auto new_h = torch::tanh(new_cell) * output_gate;
+
+  return {new_h,
+          new_cell,
+          input_gate,
+          output_gate,
+          candidate_cell,
+          X,
+          gate_weights};
+}
+
+// sigmoid'(z) = (1 - sigmoid(z)) * sigmoid(z)
+torch::Tensor d_sigmoid(torch::Tensor z) {
+  auto s = torch::sigmoid(z);
+  return (1 - s) * s;
+}
+
+// tanh'(z) = 1 - tanh^2(z)
+torch::Tensor d_tanh(torch::Tensor z) {
+  return 1 - z.tanh().pow(2);
+}
+
+// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
+torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
+  auto e = z.exp();
+  auto mask = (alpha * (e - 1)) < 0;
+  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
+}
+
+std::vector<torch::Tensor> stable_sort_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_output_gate = torch::tanh(new_cell) * grad_h;
+  auto d_tanh_new_cell = output_gate * grad_h;
+  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
+
+  auto d_old_cell = d_new_cell;
+  auto d_candidate_cell = input_gate * d_new_cell;
+  auto d_input_gate = candidate_cell * d_new_cell;
+
+  auto gates = gate_weights.chunk(3, /*dim=*/1);
+  d_input_gate *= d_sigmoid(gates[0]);
+  d_output_gate *= d_sigmoid(gates[1]);
+  d_candidate_cell *= d_elu(gates[2]);
+
+  auto d_gates =
+      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
+
+  auto d_weights = d_gates.t().mm(X);
+  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
+
+  auto d_X = d_gates.mm(weights);
+  const auto state_size = grad_h.size(1);
+  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
+  auto d_input = d_X.slice(/*dim=*/1, state_size);
+
+  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
+}
+
+std::vector<torch::Tensor> stable_argsort_forward() {
+  // TODO: not implemented yet.
+  throw std::runtime_error("Not implemented");
+}
+
+std::vector<torch::Tensor> stable_argsort_backward() {
+  // TODO: not implemented yet.
+  throw std::runtime_error("Not implemented");
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("stable_sort_forward", &stable_sort_forward, "Stable sort forward");
+  m.def("stable_sort_backward", &stable_sort_backward, "Stable sort backward");
+  m.def("stable_argsort_forward", &stable_argsort_forward, "Stable argsort forward");
+  m.def("stable_argsort_backward", &stable_argsort_backward, "Stable argsort backward");
+}
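
For context, a minimal sketch of how the extension defined by setup.py might be built and exercised from Python. The module name torch_stablesort_cpp and the four-argument signature of stable_sort_forward come from the diff above; the tensor values are illustrative assumptions, and since stable_sort_forward still carries the placeholder LLTM body, the call shows the intended interface rather than working sort behavior.

    # Build and install the extension first (from src/torch_stablesort):
    #     python setup.py install
    import torch
    import torch_stablesort_cpp

    x = torch.randn(10)
    # dim=0, descending=False, out=None, matching the C++ signature above;
    # with the placeholder body this will not yet return a sorted tensor.
    result = torch_stablesort_cpp.stable_sort_forward(x, 0, False, None)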