#include <torch/extension.h>

#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <tuple>
#include <vector>

- template<typename fun, typename... Ts>
- void dispatch(torch::Tensor input, Ts&& ... args) {
- switch(input.type().scalarType()) {
- case torch::ScalarType::Double:
- return fun<double>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Float:
- return fun<float>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Half:
- throw std::runtime_error("Half-precision float not supported");
- case torch::ScalarType::ComplexHalf:
- throw std::runtime_error("Half-precision complex float not supported");
- case torch::ScalarType::ComplexFloat:
- return fun<float64_t>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::ComplexDouble:
- return fun<float128_t>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Long:
- return fun<int64_t>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Int:
- return fun<int32_t>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Short:
- return fun<int16_t>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Char:
- return fun<int8_t>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Byte:
- return fun<uint8_t>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::Bool:
- return fun<bool>(input, std::forward<Ts>(args)...);
- case torch::ScalarType::QInt32:
- throw std::runtime_error("QInt32 not supported");
- case torch::ScalarType::QInt16:
- throw std::runtime_error("QInt16 not supported");
- case torch::ScalarType::QInt8:
- throw std::runtime_error("QInt8 not supported");
- case torch::ScalarType::BFloat16:
- throw std::runtime_error("BFloat16 not supported");
- default:
- throw std::runtime_error("Unknown scalar type");
- }
- }
-
-
- std::vector<at::Tensor> stable_sort_forward(
- torch::Tensor input,
- int dim,
- bool descending,
- torch::optional<torch::Tensor> out = nullptr) {
-
-
-
- auto X = torch::cat({old_h, input}, /*dim=*/1);
-
- auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
- auto gates = gate_weights.chunk(3, /*dim=*/1);
-
- auto input_gate = torch::sigmoid(gates[0]);
- auto output_gate = torch::sigmoid(gates[1]);
- auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
-
- auto new_cell = old_cell + candidate_cell * input_gate;
- auto new_h = torch::tanh(new_cell) * output_gate;
-
- return {new_h,
- new_cell,
- input_gate,
- output_gate,
- candidate_cell,
- X,
- gate_weights};
- }
-
- / tanh'(z) = 1 - tanh^2(z)
- torch::Tensor d_tanh(torch::Tensor z) {
- return 1 - z.tanh().pow(2);
- }
-
- // elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
- torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
- auto e = z.exp();
- auto mask = (alpha * (e - 1)) < 0;
- return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
- }
-
- std::vector<torch::Tensor> stable_sort_backward(
- torch::Tensor grad_h,
- torch::Tensor grad_cell,
- torch::Tensor new_cell,
- torch::Tensor input_gate,
- torch::Tensor output_gate,
- torch::Tensor candidate_cell,
- torch::Tensor X,
- torch::Tensor gate_weights,
- torch::Tensor weights) {
- auto d_output_gate = torch::tanh(new_cell) * grad_h;
- auto d_tanh_new_cell = output_gate * grad_h;
- auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
-
- auto d_old_cell = d_new_cell;
- auto d_candidate_cell = input_gate * d_new_cell;
- auto d_input_gate = candidate_cell * d_new_cell;
-
- auto gates = gate_weights.chunk(3, /*dim=*/1);
- d_input_gate *= d_sigmoid(gates[0]);
- d_output_gate *= d_sigmoid(gates[1]);
- d_candidate_cell *= d_elu(gates[2]);
-
- auto d_gates =
- torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
-
- auto d_weights = d_gates.t().mm(X);
- auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
-
- auto d_X = d_gates.mm(weights);
- const auto state_size = grad_h.size(1);
- auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
- auto d_input = d_X.slice(/*dim=*/1, state_size);
-
- return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
- }
-
- std::vector<torch::Tensor> stable_argsort_forward() {
-
- }
-
- std::vector<torch::Tensor> stable_argsort_backward() {
-
- }
-
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("stable_sort_forward", &stable_sort_forward, "Stable sort forward");
- m.def("stable_sort_backward", &stable_sort_backward, "Stable sort backward");
- m.def("stable_argsort_forward", &stable_argsort_forward, "Stable argsort forward");
- m.def("stable_argsort_backward", &stable_argsort_backward, "Stable argsort backward");
- }