IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

141 lines
4.7KB

#include <torch/extension.h>

#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <vector>
  5. template<typename fun, typename... Ts>
  6. void dispatch(torch::Tensor input, Ts&& ... args) {
  7. switch(input.type().scalarType()) {
  8. case torch::ScalarType::Double:
  9. return fun<double>(input, std::forward<Ts>(args)...);
  10. case torch::ScalarType::Float:
  11. return fun<float>(input, std::forward<Ts>(args)...);
  12. case torch::ScalarType::Half:
  13. throw std::runtime_error("Half-precision float not supported");
  14. case torch::ScalarType::ComplexHalf:
  15. throw std::runtime_error("Half-precision complex float not supported");
  16. case torch::ScalarType::ComplexFloat:
  17. return fun<float64_t>(input, std::forward<Ts>(args)...);
  18. case torch::ScalarType::ComplexDouble:
  19. return fun<float128_t>(input, std::forward<Ts>(args)...);
  20. case torch::ScalarType::Long:
  21. return fun<int64_t>(input, std::forward<Ts>(args)...);
  22. case torch::ScalarType::Int:
  23. return fun<int32_t>(input, std::forward<Ts>(args)...);
  24. case torch::ScalarType::Short:
  25. return fun<int16_t>(input, std::forward<Ts>(args)...);
  26. case torch::ScalarType::Char:
  27. return fun<int8_t>(input, std::forward<Ts>(args)...);
  28. case torch::ScalarType::Byte:
  29. return fun<uint8_t>(input, std::forward<Ts>(args)...);
  30. case torch::ScalarType::Bool:
  31. return fun<bool>(input, std::forward<Ts>(args)...);
  32. case torch::ScalarType::QInt32:
  33. throw std::runtime_error("QInt32 not supported");
  34. case torch::ScalarType::QInt16:
  35. throw std::runtime_error("QInt16 not supported");
  36. case torch::ScalarType::QInt8:
  37. throw std::runtime_error("QInt8 not supported");
  38. case torch::ScalarType::BFloat16:
  39. throw std::runtime_error("BFloat16 not supported");
  40. default:
  41. throw std::runtime_error("Unknown scalar type");
  42. }
  43. }
  44. std::vector<at::Tensor> stable_sort_forward(
  45. torch::Tensor input,
  46. int dim,
  47. bool descending,
  48. torch::optional<torch::Tensor> out = nullptr) {
  49. auto X = torch::cat({old_h, input}, /*dim=*/1);
  50. auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
  51. auto gates = gate_weights.chunk(3, /*dim=*/1);
  52. auto input_gate = torch::sigmoid(gates[0]);
  53. auto output_gate = torch::sigmoid(gates[1]);
  54. auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
  55. auto new_cell = old_cell + candidate_cell * input_gate;
  56. auto new_h = torch::tanh(new_cell) * output_gate;
  57. return {new_h,
  58. new_cell,
  59. input_gate,
  60. output_gate,
  61. candidate_cell,
  62. X,
  63. gate_weights};
  64. }
  65. / tanh'(z) = 1 - tanh^2(z)
  66. torch::Tensor d_tanh(torch::Tensor z) {
  67. return 1 - z.tanh().pow(2);
  68. }
  69. // elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
  70. torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
  71. auto e = z.exp();
  72. auto mask = (alpha * (e - 1)) < 0;
  73. return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
  74. }
// Backward pass paired with stable_sort_forward.
//
// NOTE(review): despite the name, this body is an LLTM-cell backward
// (gate/cell-state gradients with tanh and elu derivatives), mirroring the
// LLTM forward code pasted into stable_sort_forward above — presumably both
// came from the PyTorch C++ extension tutorial. Confirm intent before use.
//
// Relies on d_sigmoid, which is not defined in this file (must come from
// elsewhere in the project), plus the local d_tanh / d_elu helpers.
//
// @return {d_old_h, d_input, d_weights, d_bias, d_old_cell}
std::vector<torch::Tensor> stable_sort_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
  // Gradients through new_h = tanh(new_cell) * output_gate.
  auto d_output_gate = torch::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  // Chain rule through tanh, plus the direct cell-state gradient.
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;
  // Re-split the pre-activation gate matrix the same way the forward did.
  auto gates = gate_weights.chunk(3, /*dim=*/1);
  // Multiply each gate gradient by its activation's derivative at the
  // pre-activation values.
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);
  auto d_gates =
      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
  auto d_X = d_gates.mm(weights);
  // X was cat({old_h, input}, dim=1): the first state_size columns of d_X
  // belong to old_h, the remainder to input.
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);
  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
  105. std::vector<torch::Tensor> stable_argsort_forward() {
  106. }
  107. std::vector<torch::Tensor> stable_argsort_backward() {
  108. }
  109. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  110. m.def("stable_sort_forward", &stable_sort_forward, "Stable sort forward");
  111. m.def("stable_sort_backward", &stable_sort_backward, "Stable sort backward");
  112. m.def("stable_argsort_forward", &stable_argsort_forward, "Stable argsort forward");
  113. m.def("stable_argsort_backward", &stable_argsort_backward, "Stable argsort backward");
  114. }