@@ -2,6 +2,7 @@ from setuptools import setup, Extension | |||||
from torch.utils import cpp_extension | from torch.utils import cpp_extension | ||||
setup(name='torch_stablesort', | setup(name='torch_stablesort', | ||||
py_modules=['torch_stablesort'], | |||||
ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', | ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', | ||||
['torch_stablesort.cpp'], | ['torch_stablesort.cpp'], | ||||
extra_compile_args=['-fopenmp', '-ggdb'])], | extra_compile_args=['-fopenmp', '-ggdb'])], | ||||
@@ -0,0 +1,18 @@ | |||||
import torch | |||||
import torch_stablesort_cpp | |||||
class StableSort(torch.autograd.Function): | |||||
@staticmethod | |||||
def forward(ctx, input, dim=-1, descending=False, out=None): | |||||
values, indices = \ | |||||
torch_stablesort_cpp.stable_sort(input, dim, descending, out) | |||||
ctx.save_for_backward(input, indices) | |||||
return values, indices | |||||
@staticmethod | |||||
def backward(ctx, grad_values, grad_indices): | |||||
input, indices = ctx.saved_variables | |||||
res = torch.empty_like(grad_values) | |||||
res[indices] = grad_values + grad_indices | |||||
return res |