#include <torch/extension.h>

#include <stdexcept>
#include <utility>
#include <vector>

// Dispatches on the runtime scalar type of `input`: instantiates the class
// template `fun` with the matching C++ element type and invokes its call
// operator, forwarding any extra arguments.
template <template <typename> class fun, typename... Ts>
void dispatch(torch::Tensor input, Ts&&... args) {
  switch (input.scalar_type()) {
    case torch::ScalarType::Double:
      return fun<double>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::Float:
      return fun<float>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::Half:
      throw std::runtime_error("Half-precision float not supported");
    case torch::ScalarType::ComplexHalf:
      throw std::runtime_error("Half-precision complex float not supported");
    case torch::ScalarType::ComplexFloat:
      return fun<c10::complex<float>>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::ComplexDouble:
      return fun<c10::complex<double>>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::Long:
      return fun<int64_t>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::Int:
      return fun<int32_t>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::Short:
      return fun<int16_t>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::Char:
      return fun<int8_t>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::Byte:
      return fun<uint8_t>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::Bool:
      return fun<bool>()(input, std::forward<Ts>(args)...);
    case torch::ScalarType::QInt32:
      throw std::runtime_error("QInt32 not supported");
    case torch::ScalarType::QInt8:
      throw std::runtime_error("QInt8 not supported");
    case torch::ScalarType::QUInt8:
      throw std::runtime_error("QUInt8 not supported");
    case torch::ScalarType::BFloat16:
      throw std::runtime_error("BFloat16 not supported");
    default:
      throw std::runtime_error("Unknown scalar type");
  }
}

// Forward pass. The body follows the gated-cell (LLTM-style) skeleton from
// the PyTorch C++ extension tutorial, so it takes the cell tensors as inputs;
// `dim`, `descending`, and `out` belong to the intended sort interface and
// are not used yet.
std::vector<torch::Tensor> stable_sort_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell,
    int64_t dim,
    bool descending,
    torch::optional<torch::Tensor> out = torch::nullopt) {
  // Concatenate the previous hidden state with the input and compute all
  // three gate pre-activations in a single matrix multiply.
  auto X = torch::cat({old_h, input}, /*dim=*/1);
  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
  auto gates = gate_weights.chunk(3, /*dim=*/1);

  auto input_gate = torch::sigmoid(gates[0]);
  auto output_gate = torch::sigmoid(gates[1]);
  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);

  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = torch::tanh(new_cell) * output_gate;

  // Intermediate tensors are returned so the backward pass can reuse them.
  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X,
          gate_weights};
}

// sigmoid'(z) = (1 - sigmoid(z)) * sigmoid(z)
torch::Tensor d_sigmoid(torch::Tensor z) {
  auto s = torch::sigmoid(z);
  return (1 - s) * s;
}

// tanh'(z) = 1 - tanh^2(z)
torch::Tensor d_tanh(torch::Tensor z) {
  return 1 - z.tanh().pow(2);
}

// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0 }
torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}

std::vector<torch::Tensor> stable_sort_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
  // Backpropagate through new_h = tanh(new_cell) * output_gate.
  auto d_output_gate = torch::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

  // Backpropagate through new_cell = old_cell + candidate_cell * input_gate.
  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;

  // Backpropagate through the gate non-linearities.
  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);

  auto d_gates =
      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

  // Gradients of the affine map gate_weights = addmm(bias, X, weights^T).
  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
  auto d_X = d_gates.mm(weights);

  // Split d_X back into the old hidden state and the input.
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
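// A minimal usage sketch for `dispatch` above. `FillWithOnes` is a
// hypothetical functor invented purely for illustration (it is not part of
// this extension): `dispatch` instantiates the class template with the
// tensor's concrete element type and invokes its call operator. The sketch
// assumes a contiguous CPU tensor.
template <typename scalar_t>
struct FillWithOnes {
  void operator()(torch::Tensor input) {
    // scalar_t is the element type selected by the dispatch switch.
    auto* ptr = input.data_ptr<scalar_t>();
    for (int64_t i = 0; i < input.numel(); ++i) {
      ptr[i] = scalar_t(1);
    }
  }
};
// Invoked as: dispatch<FillWithOnes>(tensor);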
// TODO: stable argsort is not implemented yet; the stubs below throw until
// the kernels are written.
std::vector<torch::Tensor> stable_argsort_forward() {
  throw std::runtime_error("stable_argsort_forward is not implemented yet");
}

std::vector<torch::Tensor> stable_argsort_backward() {
  throw std::runtime_error("stable_argsort_backward is not implemented yet");
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("stable_sort_forward", &stable_sort_forward, "Stable sort forward");
  m.def("stable_sort_backward", &stable_sort_backward, "Stable sort backward");
  m.def("stable_argsort_forward", &stable_argsort_forward,
        "Stable argsort forward");
  m.def("stable_argsort_backward", &stable_argsort_backward,
        "Stable argsort backward");
}
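// Build/usage sketch: the file name "stable_sort.cpp" and the module name
// "stable_sort" are assumptions for illustration, not fixed by this file.
// The extension can be JIT-compiled and loaded from Python with the standard
// torch.utils.cpp_extension helper; the call matches stable_sort_forward as
// declared above (pass None for the unused `out` argument):
//
//   from torch.utils.cpp_extension import load
//   ext = load(name="stable_sort", sources=["stable_sort.cpp"])
//   new_h, new_cell, *saved = ext.stable_sort_forward(
//       input, weights, bias, old_h, old_cell, 0, False, None)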