IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

42 lines
1.8KB

#include <cstdint>
#include <stdexcept>
#include <utility>
  2. template<template<typename T> class F, typename R, typename... Ts>
  3. R dispatch(torch::Tensor input, Ts&& ... args) {
  4. switch(input.type().scalarType()) {
  5. case torch::ScalarType::Double:
  6. return F<double>()(input, std::forward<Ts>(args)...);
  7. case torch::ScalarType::Float:
  8. return F<float>()(input, std::forward<Ts>(args)...);
  9. case torch::ScalarType::Half:
  10. throw std::runtime_error("Half-precision float not supported");
  11. case torch::ScalarType::ComplexHalf:
  12. throw std::runtime_error("Half-precision complex float not supported");
  13. case torch::ScalarType::ComplexFloat:
  14. throw std::runtime_error("Complex float not supported");
  15. case torch::ScalarType::ComplexDouble:
  16. throw std::runtime_error("Complex double not supported");
  17. case torch::ScalarType::Long:
  18. return F<int64_t>()(input, std::forward<Ts>(args)...);
  19. case torch::ScalarType::Int:
  20. return F<int32_t>()(input, std::forward<Ts>(args)...);
  21. case torch::ScalarType::Short:
  22. return F<int16_t>()(input, std::forward<Ts>(args)...);
  23. case torch::ScalarType::Char:
  24. return F<int8_t>()(input, std::forward<Ts>(args)...);
  25. case torch::ScalarType::Byte:
  26. return F<uint8_t>()(input, std::forward<Ts>(args)...);
  27. case torch::ScalarType::Bool:
  28. return F<bool>()(input, std::forward<Ts>(args)...);
  29. case torch::ScalarType::QInt32:
  30. throw std::runtime_error("QInt32 not supported");
  31. //case torch::ScalarType::QInt16:
  32. // throw std::runtime_error("QInt16 not supported");
  33. case torch::ScalarType::QInt8:
  34. throw std::runtime_error("QInt8 not supported");
  35. case torch::ScalarType::BFloat16:
  36. throw std::runtime_error("BFloat16 not supported");
  37. default:
  38. throw std::runtime_error("Unknown scalar type");
  39. }
  40. }