IF YOU WOULD LIKE TO GET AN ACCOUNT, please email s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

44 lines
1.8KB

#pragma once
#include <cstdint>
#include <stdexcept>
#include <utility>
  3. template<template<typename T> class F, typename R, typename... Ts>
  4. R dispatch(torch::Tensor input, Ts&& ... args) {
  5. switch(input.type().scalarType()) {
  6. case torch::ScalarType::Double:
  7. return F<double>()(input, std::forward<Ts>(args)...);
  8. case torch::ScalarType::Float:
  9. return F<float>()(input, std::forward<Ts>(args)...);
  10. case torch::ScalarType::Half:
  11. throw std::runtime_error("Half-precision float not supported");
  12. case torch::ScalarType::ComplexHalf:
  13. throw std::runtime_error("Half-precision complex float not supported");
  14. case torch::ScalarType::ComplexFloat:
  15. throw std::runtime_error("Complex float not supported");
  16. case torch::ScalarType::ComplexDouble:
  17. throw std::runtime_error("Complex double not supported");
  18. case torch::ScalarType::Long:
  19. return F<int64_t>()(input, std::forward<Ts>(args)...);
  20. case torch::ScalarType::Int:
  21. return F<int32_t>()(input, std::forward<Ts>(args)...);
  22. case torch::ScalarType::Short:
  23. return F<int16_t>()(input, std::forward<Ts>(args)...);
  24. case torch::ScalarType::Char:
  25. return F<int8_t>()(input, std::forward<Ts>(args)...);
  26. case torch::ScalarType::Byte:
  27. return F<uint8_t>()(input, std::forward<Ts>(args)...);
  28. case torch::ScalarType::Bool:
  29. return F<bool>()(input, std::forward<Ts>(args)...);
  30. case torch::ScalarType::QInt32:
  31. throw std::runtime_error("QInt32 not supported");
  32. //case torch::ScalarType::QInt16:
  33. // throw std::runtime_error("QInt16 not supported");
  34. case torch::ScalarType::QInt8:
  35. throw std::runtime_error("QInt8 not supported");
  36. case torch::ScalarType::BFloat16:
  37. throw std::runtime_error("BFloat16 not supported");
  38. default:
  39. throw std::runtime_error("Unknown scalar type");
  40. }
  41. }