| @@ -2,6 +2,7 @@ from setuptools import setup, Extension | |||||
| from torch.utils import cpp_extension | from torch.utils import cpp_extension | ||||
| setup(name='torch_stablesort', | setup(name='torch_stablesort', | ||||
| py_modules=['torch_stablesort'], | |||||
| ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', | ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', | ||||
| ['torch_stablesort.cpp'], | ['torch_stablesort.cpp'], | ||||
| extra_compile_args=['-fopenmp', '-ggdb'])], | extra_compile_args=['-fopenmp', '-ggdb'])], | ||||
| @@ -0,0 +1,18 @@ | |||||
| import torch | |||||
| import torch_stablesort_cpp | |||||
| class StableSort(torch.autograd.Function): | |||||
| @staticmethod | |||||
| def forward(ctx, input, dim=-1, descending=False, out=None): | |||||
| values, indices = \ | |||||
| torch_stablesort_cpp.stable_sort(input, dim, descending, out) | |||||
| ctx.save_for_backward(input, indices) | |||||
| return values, indices | |||||
| @staticmethod | |||||
| def backward(ctx, grad_values, grad_indices): | |||||
| input, indices = ctx.saved_variables | |||||
| res = torch.empty_like(grad_values) | |||||
| res[indices] = grad_values + grad_indices | |||||
| return res | |||||