@@ -4,137 +4,136 @@
 #include <vector>
+#include <algorithm>
 
+#include "dispatch.h"
+
+// Dispatch to fun<T>, where T matches the tensor's runtime scalar type.
+template<template<typename> class fun, typename R, typename... Ts>
+R dispatch(torch::Tensor input, Ts&& ... args) {
+    switch(input.type().scalarType()) {
+        case torch::ScalarType::Double:
+            return fun<double>()(input, std::forward<Ts>(args)...);
+        case torch::ScalarType::Float:
+            return fun<float>()(input, std::forward<Ts>(args)...);
+        case torch::ScalarType::Half:
+            throw std::runtime_error("Half-precision float not supported");
+        case torch::ScalarType::ComplexHalf:
+            throw std::runtime_error("Half-precision complex float not supported");
+        case torch::ScalarType::ComplexFloat:
+            throw std::runtime_error("Complex float not supported");
+        case torch::ScalarType::ComplexDouble:
+            throw std::runtime_error("Complex double not supported");
+        case torch::ScalarType::Long:
+            return fun<int64_t>()(input, std::forward<Ts>(args)...);
+        case torch::ScalarType::Int:
+            return fun<int32_t>()(input, std::forward<Ts>(args)...);
+        case torch::ScalarType::Short:
+            return fun<int16_t>()(input, std::forward<Ts>(args)...);
+        case torch::ScalarType::Char:
+            return fun<int8_t>()(input, std::forward<Ts>(args)...);
+        case torch::ScalarType::Byte:
+            return fun<uint8_t>()(input, std::forward<Ts>(args)...);
+        case torch::ScalarType::Bool:
+            return fun<bool>()(input, std::forward<Ts>(args)...);
+        case torch::ScalarType::QInt32:
+            throw std::runtime_error("QInt32 not supported");
+        case torch::ScalarType::QUInt8:
+            throw std::runtime_error("QUInt8 not supported");
+        case torch::ScalarType::QInt8:
+            throw std::runtime_error("QInt8 not supported");
+        case torch::ScalarType::BFloat16:
+            throw std::runtime_error("BFloat16 not supported");
+        default:
+            throw std::runtime_error("Unknown scalar type");
+    }
+}
+
 
-std::vector<at::Tensor> stable_sort_forward(
-    torch::Tensor input,
-    int dim,
-    bool descending,
-    torch::optional<torch::Tensor> out = nullptr) {
-
-    auto X = torch::cat({old_h, input}, /*dim=*/1);
-
-    auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
-    auto gates = gate_weights.chunk(3, /*dim=*/1);
-
-    auto input_gate = torch::sigmoid(gates[0]);
-    auto output_gate = torch::sigmoid(gates[1]);
-    auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
-
-    auto new_cell = old_cell + candidate_cell * input_gate;
-    auto new_h = torch::tanh(new_cell) * output_gate;
-
-    return {new_h,
-        new_cell,
-        input_gate,
-        output_gate,
-        candidate_cell,
-        X,
-        gate_weights};
-}
-
-// tanh'(z) = 1 - tanh^2(z)
-torch::Tensor d_tanh(torch::Tensor z) {
-    return 1 - z.tanh().pow(2);
-}
-
-// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
-torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
-    auto e = z.exp();
-    auto mask = (alpha * (e - 1)) < 0;
-    return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
-}
-
-std::vector<torch::Tensor> stable_sort_backward(
-    torch::Tensor grad_h,
-    torch::Tensor grad_cell,
-    torch::Tensor new_cell,
-    torch::Tensor input_gate,
-    torch::Tensor output_gate,
-    torch::Tensor candidate_cell,
-    torch::Tensor X,
-    torch::Tensor gate_weights,
-    torch::Tensor weights) {
-    auto d_output_gate = torch::tanh(new_cell) * grad_h;
-    auto d_tanh_new_cell = output_gate * grad_h;
-    auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
-
-    auto d_old_cell = d_new_cell;
-    auto d_candidate_cell = input_gate * d_new_cell;
-    auto d_input_gate = candidate_cell * d_new_cell;
-
-    auto gates = gate_weights.chunk(3, /*dim=*/1);
-    d_input_gate *= d_sigmoid(gates[0]);
-    d_output_gate *= d_sigmoid(gates[1]);
-    d_candidate_cell *= d_elu(gates[2]);
-
-    auto d_gates =
-        torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
-
-    auto d_weights = d_gates.t().mm(X);
-    auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
-
-    auto d_X = d_gates.mm(weights);
-    const auto state_size = grad_h.size(1);
-    auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
-    auto d_input = d_X.slice(/*dim=*/1, state_size);
-
-    return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
-}
-
-std::vector<torch::Tensor> stable_argsort_forward() {
-
-}
-
-std::vector<torch::Tensor> stable_argsort_backward() {
-
-}
+template<typename T>
+struct stable_sort_impl {
+    std::vector<torch::Tensor> operator()(
+        torch::Tensor input,
+        int dim,
+        bool descending,
+        torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out
+    ) const {
+
+        if (input.is_sparse())
+            throw std::runtime_error("Sparse tensors are not supported");
+
+        if (input.device().type() != torch::DeviceType::CPU)
+            throw std::runtime_error("Only CPU tensors are supported");
+
+        if (out != torch::nullopt)
+            throw std::runtime_error("out argument is not supported");
+
+        // Bring the sort dimension to the last position.
+        auto in = (dim != -1) ?
+            torch::transpose(input, dim, -1) :
+            input;
+
+        auto in_sizes = in.sizes().vec(); // copy: `in` is reassigned below
+
+        // std::cout << "in_sizes: " << in_sizes << std::endl;
+
+        // Flatten to 2D (n_rows x n_cols) and sort each row independently.
+        in = in.contiguous().view({ -1, in.size(-1) });
+
+        auto in_outer_stride = in.stride(-2);
+        auto in_inner_stride = in.stride(-1);
+
+        auto pin = static_cast<T*>(in.data_ptr());
+
+        auto x = in.clone();
+
+        auto x_outer_stride = x.stride(-2);
+        auto x_inner_stride = x.stride(-1);
+
+        auto n_cols = x.size(1);
+        auto n_rows = x.size(0);
+        auto px = static_cast<T*>(x.data_ptr());
+
+        auto y = torch::empty({ n_rows, n_cols },
+            torch::TensorOptions().dtype(torch::kInt64));
+
+        auto y_outer_stride = y.stride(-2);
+        auto y_inner_stride = y.stride(-1);
+
+        auto py = static_cast<int64_t*>(y.data_ptr());
+
+        if (descending) {
+            #pragma omp parallel for
+            for (decltype(n_rows) i = 0; i < n_rows; i++) {
+
+                std::vector<int64_t> indices(n_cols);
+                for (decltype(n_cols) k = 0; k < n_cols; k++) {
+                    indices[k] = k;
+                }
+
+                std::stable_sort(std::begin(indices), std::end(indices),
+                    [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
+                        auto va = pin[i * in_outer_stride + a * in_inner_stride];
+                        auto vb = pin[i * in_outer_stride + b * in_inner_stride];
+                        return (vb < va);
+                    });
+
+                for (decltype(n_cols) k = 0; k < n_cols; k++) {
+                    py[i * y_outer_stride + k * y_inner_stride] = indices[k];
+                    px[i * x_outer_stride + k * x_inner_stride] =
+                        pin[i * in_outer_stride + indices[k] * in_inner_stride];
+                }
+            }
+        } else {
+            #pragma omp parallel for
+            for (decltype(n_rows) i = 0; i < n_rows; i++) {
+
+                std::vector<int64_t> indices(n_cols);
+                for (decltype(n_cols) k = 0; k < n_cols; k++) {
+                    indices[k] = k;
+                }
+
+                std::stable_sort(std::begin(indices), std::end(indices),
+                    [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
+                        auto va = pin[i * in_outer_stride + a * in_inner_stride];
+                        auto vb = pin[i * in_outer_stride + b * in_inner_stride];
+                        return (va < vb);
+                    });
+
+                for (decltype(n_cols) k = 0; k < n_cols; k++) {
+                    py[i * y_outer_stride + k * y_inner_stride] = indices[k];
+                    px[i * x_outer_stride + k * x_inner_stride] =
+                        pin[i * in_outer_stride + indices[k] * in_inner_stride];
+                }
+            }
+        }
+
+        // std::cout << "Here" << std::endl;
+
+        // Restore the original shape and dimension order.
+        x = x.view(in_sizes);
+        y = y.view(in_sizes);
+
+        x = (dim == -1) ?
+            x :
+            torch::transpose(x, dim, -1).contiguous();
+
+        y = (dim == -1) ?
+            y :
+            torch::transpose(y, dim, -1).contiguous();
+
+        // std::cout << "Here 2" << std::endl;
+
+        return { x, y };
+    }
+};
+
+std::vector<torch::Tensor> stable_sort(
+    torch::Tensor input,
+    int dim = -1,
+    bool descending = false,
+    torch::optional<std::tuple<torch::Tensor, torch::Tensor>> out = torch::nullopt) {
+
+    return dispatch<stable_sort_impl, std::vector<torch::Tensor>>(
+        input, dim, descending, out);
+}
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("stable_sort_forward", &stable_sort_forward, "Stable sort forward");
-    m.def("stable_sort_backward", &stable_sort_backward, "Stable sort backward");
-    m.def("stable_argsort_forward", &stable_argsort_forward, "Stable argsort forward");
-    m.def("stable_argsort_backward", &stable_argsort_backward, "Stable argsort backward");
+    m.def("stable_sort", &stable_sort, "Stable sort",
+        py::arg("input"), py::arg("dim") = -1, py::arg("descending") = false,
+        py::arg("out") = nullptr);
 }
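For reference, here is a minimal, hypothetical sketch of how the new stable_sort binding could be built and exercised from Python. It assumes the updated source is saved as torch_stablesort.cpp (with the dispatch helper visible to the compiler, e.g. via dispatch.h), and that PyTorch and a C++14 toolchain are installed; the file name and flags are assumptions, not part of the diff. The OpenMP flags are optional: without them the pragma is ignored and each row is sorted serially.

import torch
from torch.utils.cpp_extension import load

# JIT-compile the extension (hypothetical file name and flags).
ext = load(
    name="torch_stablesort",
    sources=["torch_stablesort.cpp"],
    extra_cflags=["-fopenmp"],
    extra_ldflags=["-fopenmp"],
    verbose=True)

x = torch.tensor([[3., 1., 2.],
                  [2., 2., 1.]])

values, indices = ext.stable_sort(x, dim=-1, descending=False)
print(values)   # per-row sorted values
print(indices)  # positions of those values in the input; ties keep their original order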