diff --git a/src/torch_stablesort/setup.py b/src/torch_stablesort/setup.py index 163d8ff..a45a47f 100644 --- a/src/torch_stablesort/setup.py +++ b/src/torch_stablesort/setup.py @@ -2,6 +2,7 @@ from setuptools import setup, Extension from torch.utils import cpp_extension setup(name='torch_stablesort', + py_modules=['torch_stablesort'], ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', ['torch_stablesort.cpp'], extra_compile_args=['-fopenmp', '-ggdb'])], diff --git a/src/torch_stablesort/torch_stablesort.py b/src/torch_stablesort/torch_stablesort.py new file mode 100644 index 0000000..4a171d0 --- /dev/null +++ b/src/torch_stablesort/torch_stablesort.py @@ -0,0 +1,18 @@ +import torch +import torch_stablesort_cpp + + +class StableSort(torch.autograd.Function): + @staticmethod + def forward(ctx, input, dim=-1, descending=False, out=None): + values, indices = \ + torch_stablesort_cpp.stable_sort(input, dim, descending, out) + ctx.save_for_backward(input, indices) + return values, indices + + @staticmethod + def backward(ctx, grad_values, grad_indices): + input, indices = ctx.saved_variables + res = torch.empty_like(grad_values) + res[indices] = grad_values + grad_indices + return res