#include #include #include #include #include "torch_stablesort_cuda.h" #include "torch_stablesort_cpu.h" std::vector stable_sort( torch::Tensor input, int dim = -1, bool descending = false, torch::optional> out = torch::nullopt) { switch (input.device().type()) { case torch::DeviceType::CUDA: return dispatch_cuda(input, dim, descending, out); case torch::DeviceType::CPU: return dispatch_cpu(input, dim, descending, out); default: throw std::runtime_error("Unsupported device type"); } } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("stable_sort", &stable_sort, "Stable sort", py::arg("input"), py::arg("dim") = -1, py::arg("descending") = false, py::arg("out") = nullptr); }