@@ -0,0 +1,29 @@
# torch_stablesort
## Introduction
### Stable sorting algorithms
Stable sorting algorithms keep repeated elements in the same order in which they appear in the input. When sorting some kinds of data, only part of each item is examined to determine the sort order. For example, when playing cards are sorted by rank and their suit is ignored, several different outputs are all correctly sorted. A stable sorting algorithm chooses one of them according to the following rule: if two items compare as equal, like two cards of rank 5, their relative order is preserved, so the one that came first in the input also comes first in the output.
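As a small illustration of the rule above, here is a sketch in plain Python rather than PyTorch. The card list is made up for the example; the only library fact it relies on is that Python's built-in `sorted()` is stable.

```python
# Cards as (rank, suit) pairs; two of them share rank 5.
cards = [(7, "spades"), (5, "hearts"), (2, "hearts"), (5, "spades")]

# Sort by rank only, ignoring suit. Because sorted() is stable, the
# 5 of hearts stays ahead of the 5 of spades, as in the input.
by_rank = sorted(cards, key=lambda card: card[0])

print(by_rank)
```

An unstable algorithm could just as validly return the 5 of spades first, since both orderings are sorted by rank.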
### PyTorch
PyTorch is an open-source machine learning library based on the Torch library. It is used for applications such as computer vision and natural language processing, and is developed primarily by Facebook's AI Research lab. It is free software released under the Modified BSD license.
### PyTorch Extensions
PyTorch provides a plethora of operations related to neural networks, arbitrary tensor algebra, data wrangling and other purposes. However, you may still find yourself in need of a more customized operation. For example, you might want to use a novel activation function you found in a paper, or implement an operation you developed as part of your research.
The easiest way to integrate such a custom operation into PyTorch is to write it in Python by extending `Function` and `Module`, as outlined in the PyTorch documentation. This gives you the full power of automatic differentiation (sparing you from writing derivative functions) as well as the usual expressiveness of Python. However, there may be times when your operation is better implemented in C++. For example, your code may need to be very fast because it is called frequently in your model, or because it is expensive even for a few calls. Another plausible reason is that it depends on or interacts with other C or C++ libraries. To address such cases, PyTorch provides a simple way of writing custom C++ extensions.
## Implementation
### setup.py
### dispatch.h
### torch_stablesort.cpp
### torch_stablesort.py
@@ -7,12 +7,22 @@ class StableSort(torch.autograd.Function):
     def forward(ctx, input, dim=-1, descending=False, out=None):
         values, indices = \
             torch_stablesort_cpp.stable_sort(input, dim, descending, out)
-        ctx.save_for_backward(input, indices)
-        return values, indices
+        ctx.save_for_backward(input, indices, torch.tensor(dim))
+        return values, indices.detach()
     @staticmethod
     def backward(ctx, grad_values, grad_indices):
-        input, indices = ctx.saved_variables
-        res = torch.empty_like(grad_values)
-        res[indices] = grad_values + grad_indices
+        input, indices, dim = ctx.saved_variables
+        # print('backward(), grad_indices:', grad_indices, 'indices:', indices,
+        # 'grad_values:', grad_values)
+        res = torch.gather(grad_values, dim, indices)
+        # res = torch.empty_like(grad_values)
+        # print('here')
+        # res = res.view(-1, res.size(-1))
+        # indices = indices.view(-1, indices.size(-1))
+        # torch.repeat_interleave(torch.arange(indices.size(0))
+        # res[indices] = grad_values # + grad_indices
+        # print('here 2')
         return res
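A note on the new `backward`: it routes gradients with `torch.gather(grad_values, dim, indices)`. The one-dimensional sketch below, in plain Python without PyTorch, suggests the gradient may need to be scattered back through `indices` rather than gathered. The helper names `stable_argsort` and `sort_backward` are hypothetical and exist only for this illustration. It assumes the forward pass satisfies `values[i] = input[indices[i]]`.

```python
def stable_argsort(xs):
    # sorted() is stable, so equal elements keep their input order.
    return sorted(range(len(xs)), key=lambda i: xs[i])


def sort_backward(grad_values, indices):
    # Forward: values[i] = xs[indices[i]], so the gradient that arrives
    # at output slot i belongs to input slot indices[i] (a scatter).
    grad_input = [0.0] * len(grad_values)
    for i, src in enumerate(indices):
        grad_input[src] = grad_values[i]
    return grad_input


xs = [3.0, 1.0, 2.0]
indices = stable_argsort(xs)
grad_values = [1.0, 10.0, 100.0]

scattered = sort_backward(grad_values, indices)   # routes grads to their inputs
gathered = [grad_values[j] for j in indices]      # what a gather would compute

print(indices, scattered, gathered)
```

The two results differ whenever the sort permutation is not its own inverse. In PyTorch terms, the scatter would presumably be written as `torch.zeros_like(grad_values).scatter(dim, indices, grad_values)`, with `dim` kept as a plain Python integer (for example on `ctx`) rather than saved as a tensor. `backward` is also expected to return one entry per `forward` argument, with `None` for the non-tensor ones.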