From 87953842a63d7b12319ab108e5f31e2ee6c091aa Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Mon, 24 Aug 2020 13:28:53 +0200 Subject: [PATCH] Try to implement gradient. --- src/torch_stablesort/setup.py | 1 + src/torch_stablesort/torch_stablesort.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 src/torch_stablesort/torch_stablesort.py diff --git a/src/torch_stablesort/setup.py b/src/torch_stablesort/setup.py index 163d8ff..a45a47f 100644 --- a/src/torch_stablesort/setup.py +++ b/src/torch_stablesort/setup.py @@ -2,6 +2,7 @@ from setuptools import setup, Extension from torch.utils import cpp_extension setup(name='torch_stablesort', + py_modules=['torch_stablesort'], ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp', ['torch_stablesort.cpp'], extra_compile_args=['-fopenmp', '-ggdb'])], diff --git a/src/torch_stablesort/torch_stablesort.py b/src/torch_stablesort/torch_stablesort.py new file mode 100644 index 0000000..4a171d0 --- /dev/null +++ b/src/torch_stablesort/torch_stablesort.py @@ -0,0 +1,18 @@ +import torch +import torch_stablesort_cpp + + +class StableSort(torch.autograd.Function): + @staticmethod + def forward(ctx, input, dim=-1, descending=False, out=None): + values, indices = \ + torch_stablesort_cpp.stable_sort(input, dim, descending, out) + ctx.save_for_backward(input, indices) + return values, indices + + @staticmethod + def backward(ctx, grad_values, grad_indices): + input, indices = ctx.saved_variables + res = torch.empty_like(grad_values) + res[indices] = grad_values + grad_indices + return res