
Start torch_stablesort.0

master
Stanislaw Adaszewski 4 years ago
parent
commit
37c5db1330
3 changed files with 180 additions and 0 deletions
  1. +33 -0 src/torch_stablesort/dispatch_test.cpp
  2. +7 -0 src/torch_stablesort/setup.py
  3. +140 -0 src/torch_stablesort/torch_stablesort.cpp

+33 -0 src/torch_stablesort/dispatch_test.cpp

@@ -0,0 +1,33 @@
#include <utility>
#include <iostream>
#include <typeinfo>
#include <stdexcept>

// Test of a dispatch pattern: the generic call<T>() throws, while an explicit
// specialization of the member template for Dispatcher<bla, int, char, double>
// supplies the real implementation.
template<typename fun, typename... Ts>
class Dispatcher {
public:
  void dispatch(int x, Ts&& ...args) {
    if (x > 0)
      call<int>(std::forward<Ts>(args)...);
    else
      call<double>(std::forward<Ts>(args)...);
  }

protected:
  template<typename T>
  void call(Ts&& ...args) {
    throw std::runtime_error("Not implemented");
  }
};

class bla {};

template<> template<typename T>
void Dispatcher<bla, int, char, double>::call(int&& a, char&& b, double&& c) {
  std::cout << typeid(T).name() << " " << a << " " << b << " " << c << std::endl;
}

int main() {
  std::cout << "main()" << std::endl;
  Dispatcher<bla, int, char, double> d;
  d.dispatch(5, 1, 'a', 5.5);
  d.dispatch(-5, 1, 'a', 5.5);
}
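
Compiled with any C++11-or-later compiler (e.g. g++ -std=c++11 dispatch_test.cpp), the test should print something like the following; the first column is the typeid name of the selected type, which under GCC's name mangling is "i" for int and "d" for double:

main()
i 1 a 5.5
d 1 a 5.5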

+7 -0 src/torch_stablesort/setup.py

@@ -0,0 +1,7 @@
from setuptools import setup, Extension
from torch.utils import cpp_extension

setup(name='torch_stablesort',
    ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp',
        ['torch_stablesort.cpp'])],
    cmdclass={'build_ext': cpp_extension.BuildExtension})
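
Assuming the standard torch.utils.cpp_extension workflow (nothing beyond the script above is specified in this commit), running python setup.py install from src/torch_stablesort would compile torch_stablesort.cpp against the active PyTorch installation and make the bindings importable from Python as the torch_stablesort_cpp module.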

+140 -0 src/torch_stablesort/torch_stablesort.cpp

@@ -0,0 +1,140 @@
#include <torch/extension.h>

#include <iostream>
#include <vector>
#include <algorithm>
#include <utility>
#include <stdexcept>
// Invokes fun<scalar_t>()(input, args...) with scalar_t chosen from the runtime
// dtype of the input tensor. fun must be a class template whose instantiations are
// default-constructible functors returning void.
template<template<typename> class fun, typename... Ts>
void dispatch(torch::Tensor input, Ts&& ... args) {
  switch(input.scalar_type()) {
  case torch::ScalarType::Double:
    return fun<double>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::Float:
    return fun<float>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::Half:
    throw std::runtime_error("Half-precision float not supported");
  case torch::ScalarType::ComplexHalf:
    throw std::runtime_error("Half-precision complex float not supported");
  case torch::ScalarType::ComplexFloat:
    return fun<c10::complex<float>>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::ComplexDouble:
    return fun<c10::complex<double>>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::Long:
    return fun<int64_t>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::Int:
    return fun<int32_t>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::Short:
    return fun<int16_t>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::Char:
    return fun<int8_t>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::Byte:
    return fun<uint8_t>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::Bool:
    return fun<bool>()(input, std::forward<Ts>(args)...);
  case torch::ScalarType::QInt32:
    throw std::runtime_error("QInt32 not supported");
  case torch::ScalarType::QUInt8:
    throw std::runtime_error("QUInt8 not supported");
  case torch::ScalarType::QInt8:
    throw std::runtime_error("QInt8 not supported");
  case torch::ScalarType::BFloat16:
    throw std::runtime_error("BFloat16 not supported");
  default:
    throw std::runtime_error("Unknown scalar type");
  }
}
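
// Illustration with a hypothetical functor (not referenced elsewhere in this file):
// dispatch<> above expects its first template argument to be a class template F such
// that F<scalar_t> is a default-constructible functor invoked as
// F<scalar_t>()(tensor, extra args...) and returning void.
template<typename T>
struct print_element_size {
  void operator()(torch::Tensor input, const char* label) const {
    std::cout << label << ": " << input.numel() << " elements of "
      << sizeof(T) << " bytes each" << std::endl;
  }
};
// Example: dispatch<print_element_size>(t, "t"); picks T from t.scalar_type() at runtime.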
std::vector<torch::Tensor> stable_sort_forward(
    torch::Tensor input,
    int dim,
    bool descending,
    torch::optional<torch::Tensor> out = torch::nullopt) {

  // Placeholder body carried over from the PyTorch C++ extension (LLTM) tutorial;
  // the stable sort itself is not implemented yet, and old_h, old_cell, weights
  // and bias are not defined in this file.
  auto X = torch::cat({old_h, input}, /*dim=*/1);
  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
  auto gates = gate_weights.chunk(3, /*dim=*/1);

  auto input_gate = torch::sigmoid(gates[0]);
  auto output_gate = torch::sigmoid(gates[1]);
  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);

  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = torch::tanh(new_cell) * output_gate;

  return {new_h,
    new_cell,
    input_gate,
    output_gate,
    candidate_cell,
    X,
    gate_weights};
}
// sigmoid'(z) = (1 - sigmoid(z)) * sigmoid(z); used by the backward pass below
torch::Tensor d_sigmoid(torch::Tensor z) {
  auto s = torch::sigmoid(z);
  return (1 - s) * s;
}

// tanh'(z) = 1 - tanh^2(z)
torch::Tensor d_tanh(torch::Tensor z) {
  return 1 - z.tanh().pow(2);
}

// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0 }
torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}
// Placeholder backward pass, also carried over from the LLTM tutorial; it does not
// yet correspond to a stable sort.
std::vector<torch::Tensor> stable_sort_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
  auto d_output_gate = torch::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;

  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);

  auto d_gates =
    torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gates.mm(weights);
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
std::vector<torch::Tensor> stable_argsort_forward() {
  // TODO: not implemented in this commit
  return {};
}

std::vector<torch::Tensor> stable_argsort_backward() {
  // TODO: not implemented in this commit
  return {};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("stable_sort_forward", &stable_sort_forward, "Stable sort forward");
m.def("stable_sort_backward", &stable_sort_backward, "Stable sort backward");
m.def("stable_argsort_forward", &stable_argsort_forward, "Stable argsort forward");
m.def("stable_argsort_backward", &stable_argsort_backward, "Stable argsort backward");
}
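
The stable_argsort_forward / stable_argsort_backward stubs above are still empty at this point. For orientation only, here is a minimal sketch (a hypothetical helper, not this repository's implementation) of what a stable argsort could look like for a 1-D, contiguous, real-valued CPU tensor, using std::stable_sort over an index vector so that equal elements keep their original order; it additionally needs <numeric> for std::iota:

template<typename scalar_t>
torch::Tensor stable_argsort_1d(torch::Tensor input, bool descending) {
  // Assumes input is 1-D, contiguous, on the CPU, and of dtype scalar_t.
  auto acc = input.accessor<scalar_t, 1>();
  std::vector<int64_t> idx(acc.size(0));
  std::iota(idx.begin(), idx.end(), 0);
  // std::stable_sort preserves the relative order of equal elements.
  std::stable_sort(idx.begin(), idx.end(), [&acc, descending](int64_t a, int64_t b) {
    return descending ? (acc[a] > acc[b]) : (acc[a] < acc[b]);
  });
  return torch::tensor(idx, torch::dtype(torch::kInt64));
}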
