|
123456789101112131415161718192021222324252627282930 |
- #include <torch/extension.h>
-
- #include <iostream>
- #include <vector>
- #include <algorithm>
-
- #include "torch_stablesort_cuda.h"
- #include "torch_stablesort_cpu.h"
-
- std::vector<torch::Tensor> stable_sort(
- torch::Tensor input,
- int dim = -1,
- bool descending = false,
- torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) {
-
- switch (input.device().type()) {
- case torch::DeviceType::CUDA:
- return dispatch_cuda(input, dim, descending, out);
- case torch::DeviceType::CPU:
- return dispatch_cpu(input, dim, descending, out);
- default:
- throw std::runtime_error("Unsupported device type");
- }
- }
-
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("stable_sort", &stable_sort, "Stable sort",
- py::arg("input"), py::arg("dim") = -1, py::arg("descending") = false,
- py::arg("out") = nullptr);
- }
|