IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

31 lines
865B

  1. #include <torch/extension.h>
  2. #include <iostream>
  3. #include <vector>
  4. #include <algorithm>
  5. #include "torch_stablesort_cuda.h"
  6. #include "torch_stablesort_cpu.h"
  7. std::vector<torch::Tensor> stable_sort(
  8. torch::Tensor input,
  9. int dim = -1,
  10. bool descending = false,
  11. torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) {
  12. switch (input.device().type()) {
  13. case torch::DeviceType::CUDA:
  14. return dispatch_cuda(input, dim, descending, out);
  15. case torch::DeviceType::CPU:
  16. return dispatch_cpu(input, dim, descending, out);
  17. default:
  18. throw std::runtime_error("Unsupported device type");
  19. }
  20. }
  21. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  22. m.def("stable_sort", &stable_sort, "Stable sort",
  23. py::arg("input"), py::arg("dim") = -1, py::arg("descending") = false,
  24. py::arg("out") = nullptr);
  25. }