#include template class F, typename R, typename... Ts> R dispatch(torch::Tensor input, Ts&& ... args) { switch(input.type().scalarType()) { case torch::ScalarType::Double: return F()(input, std::forward(args)...); case torch::ScalarType::Float: return F()(input, std::forward(args)...); case torch::ScalarType::Half: throw std::runtime_error("Half-precision float not supported"); case torch::ScalarType::ComplexHalf: throw std::runtime_error("Half-precision complex float not supported"); case torch::ScalarType::ComplexFloat: throw std::runtime_error("Complex float not supported"); case torch::ScalarType::ComplexDouble: throw std::runtime_error("Complex double not supported"); case torch::ScalarType::Long: return F()(input, std::forward(args)...); case torch::ScalarType::Int: return F()(input, std::forward(args)...); case torch::ScalarType::Short: return F()(input, std::forward(args)...); case torch::ScalarType::Char: return F()(input, std::forward(args)...); case torch::ScalarType::Byte: return F()(input, std::forward(args)...); case torch::ScalarType::Bool: return F()(input, std::forward(args)...); case torch::ScalarType::QInt32: throw std::runtime_error("QInt32 not supported"); //case torch::ScalarType::QInt16: // throw std::runtime_error("QInt16 not supported"); case torch::ScalarType::QInt8: throw std::runtime_error("QInt8 not supported"); case torch::ScalarType::BFloat16: throw std::runtime_error("BFloat16 not supported"); default: throw std::runtime_error("Unknown scalar type"); } }