IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

45 lines
1.9KB

#pragma once

#include <stdexcept>
#include <string>
#include <utility>

#include <torch/extension.h>
  4. template<template<typename T> class F, typename R, typename... Ts>
  5. R dispatch(torch::Tensor input, Ts&& ... args) {
  6. switch(input.type().scalarType()) {
  7. case torch::ScalarType::Double:
  8. return F<double>()(input, std::forward<Ts>(args)...);
  9. case torch::ScalarType::Float:
  10. return F<float>()(input, std::forward<Ts>(args)...);
  11. case torch::ScalarType::Half:
  12. throw std::runtime_error("Half-precision float not supported");
  13. case torch::ScalarType::ComplexHalf:
  14. throw std::runtime_error("Half-precision complex float not supported");
  15. case torch::ScalarType::ComplexFloat:
  16. throw std::runtime_error("Complex float not supported");
  17. case torch::ScalarType::ComplexDouble:
  18. throw std::runtime_error("Complex double not supported");
  19. case torch::ScalarType::Long:
  20. return F<int64_t>()(input, std::forward<Ts>(args)...);
  21. case torch::ScalarType::Int:
  22. return F<int32_t>()(input, std::forward<Ts>(args)...);
  23. case torch::ScalarType::Short:
  24. return F<int16_t>()(input, std::forward<Ts>(args)...);
  25. case torch::ScalarType::Char:
  26. return F<int8_t>()(input, std::forward<Ts>(args)...);
  27. case torch::ScalarType::Byte:
  28. return F<uint8_t>()(input, std::forward<Ts>(args)...);
  29. case torch::ScalarType::Bool:
  30. return F<bool>()(input, std::forward<Ts>(args)...);
  31. case torch::ScalarType::QInt32:
  32. throw std::runtime_error("QInt32 not supported");
  33. //case torch::ScalarType::QInt16:
  34. // throw std::runtime_error("QInt16 not supported");
  35. case torch::ScalarType::QInt8:
  36. throw std::runtime_error("QInt8 not supported");
  37. case torch::ScalarType::BFloat16:
  38. throw std::runtime_error("BFloat16 not supported");
  39. default:
  40. throw std::runtime_error("Unknown scalar type");
  41. }
  42. }