IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Try to implement gradient.

master
Stanislaw Adaszewski 3 years ago
parent
commit
87953842a6
2 changed files with 19 additions and 0 deletions
  1. +1
    -0
      src/torch_stablesort/setup.py
  2. +18
    -0
      src/torch_stablesort/torch_stablesort.py

+ 1
- 0
src/torch_stablesort/setup.py View File

@@ -2,6 +2,7 @@ from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name='torch_stablesort',
py_modules=['torch_stablesort'],
ext_modules=[cpp_extension.CppExtension('torch_stablesort_cpp',
['torch_stablesort.cpp'],
extra_compile_args=['-fopenmp', '-ggdb'])],


+ 18
- 0
src/torch_stablesort/torch_stablesort.py View File

@@ -0,0 +1,18 @@
import torch
import torch_stablesort_cpp
class StableSort(torch.autograd.Function):
@staticmethod
def forward(ctx, input, dim=-1, descending=False, out=None):
values, indices = \
torch_stablesort_cpp.stable_sort(input, dim, descending, out)
ctx.save_for_backward(input, indices)
return values, indices
@staticmethod
def backward(ctx, grad_values, grad_indices):
input, indices = ctx.saved_variables
res = torch.empty_like(grad_values)
res[indices] = grad_values + grad_indices
return res

Loading…
Cancel
Save