|
1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- #pragma once
-
- #include <torch/extension.h>
- #include <utility>
-
- template<template<typename T> class F, typename R, typename... Ts>
- R dispatch(torch::Tensor input, Ts&& ... args) {
- switch(input.type().scalarType()) {
- case torch::ScalarType::Double:
- return F<double>()(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Float:
- return F<float>()(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Half:
- throw std::runtime_error("Half-precision float not supported");
- case torch::ScalarType::ComplexHalf:
- throw std::runtime_error("Half-precision complex float not supported");
- case torch::ScalarType::ComplexFloat:
- throw std::runtime_error("Complex float not supported");
- case torch::ScalarType::ComplexDouble:
- throw std::runtime_error("Complex double not supported");
- case torch::ScalarType::Long:
- return F<int64_t>()(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Int:
- return F<int32_t>()(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Short:
- return F<int16_t>()(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Char:
- return F<int8_t>()(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Byte:
- return F<uint8_t>()(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Bool:
- return F<bool>()(input, std::forward<Ts>(args)...);
- case torch::ScalarType::QInt32:
- throw std::runtime_error("QInt32 not supported");
- //case torch::ScalarType::QInt16:
- // throw std::runtime_error("QInt16 not supported");
- case torch::ScalarType::QInt8:
- throw std::runtime_error("QInt8 not supported");
- case torch::ScalarType::BFloat16:
- throw std::runtime_error("BFloat16 not supported");
- default:
- throw std::runtime_error("Unknown scalar type");
- }
- }
|