From 283d9387a8549a16bec6a60a7e0934f97d9f1ae4 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Mon, 24 Aug 2020 18:31:26 +0200 Subject: [PATCH] Add torch_stablesort.md. --- src/torch_stablesort/torch_stablesort.md | 29 ++++++++++++++++++++++++ src/torch_stablesort/torch_stablesort.py | 20 ++++++++++++---- 2 files changed, 44 insertions(+), 5 deletions(-) create mode 100644 src/torch_stablesort/torch_stablesort.md diff --git a/src/torch_stablesort/torch_stablesort.md b/src/torch_stablesort/torch_stablesort.md new file mode 100644 index 0000000..e15d2ee --- /dev/null +++ b/src/torch_stablesort/torch_stablesort.md @@ -0,0 +1,29 @@ +# torch_stablesort + +## Introduction + +### Stable sorting algorithms + +Stable sort algorithms sort repeated elements in the same order that they appear in the input. When sorting some kinds of data, only part of the data is examined when determining the sort order. For example, when sorting a hand of playing cards by their rank, their suit is being ignored. This allows the possibility of multiple different correctly sorted versions of the original list. Stable sorting algorithms choose one of these, according to the following rule: if two items compare as equal, like two 5 cards of different suits, then their relative order will be preserved, so that if one came before the other in the input, it will also come before the other in the output. + +### PyTorch + +PyTorch is an open source machine learning library based on the Torch library, used for applications such as computer vision and natural language processing, primarily developed by Facebook's AI Research lab. It is free and open-source software released under the Modified BSD license. + +### PyTorch Extensions + +PyTorch provides a plethora of operations related to neural networks, arbitrary tensor algebra, data wrangling and other purposes. However, you may still find yourself in need of a more customized operation. 
For example, you might want to use a novel activation function you found in a paper, or implement an operation you developed as part of your research. + +The easiest way of integrating such a custom operation in PyTorch is to write it in Python by extending Function and Module, as outlined in the official PyTorch tutorial on custom C++ and CUDA extensions. This gives you the full power of automatic differentiation (spares you from writing derivative functions) as well as the usual expressiveness of Python. However, there may be times when your operation is better implemented in C++. For example, your code may need to be really fast because it is called very frequently in your model or is very expensive even for few calls. Another plausible reason is that it depends on or interacts with other C or C++ libraries. To address such cases, PyTorch provides a very easy way of writing custom C++ extensions. + +## Implementation + +### setup.py + + + +### dispatch.h + +### torch_stablesort.cpp + +### torch_stablesort.py diff --git a/src/torch_stablesort/torch_stablesort.py b/src/torch_stablesort/torch_stablesort.py index 4a171d0..83b7574 100644 --- a/src/torch_stablesort/torch_stablesort.py +++ b/src/torch_stablesort/torch_stablesort.py @@ -7,12 +7,22 @@ class StableSort(torch.autograd.Function): def forward(ctx, input, dim=-1, descending=False, out=None): values, indices = \ torch_stablesort_cpp.stable_sort(input, dim, descending, out) - ctx.save_for_backward(input, indices) - return values, indices + ctx.save_for_backward(input, indices, torch.tensor(dim)) + return values, indices.detach() @staticmethod def backward(ctx, grad_values, grad_indices): - input, indices = ctx.saved_variables - res = torch.empty_like(grad_values) - res[indices] = grad_values + grad_indices + input, indices, dim = ctx.saved_variables + # print('backward(), grad_indices:', grad_indices, 'indices:', indices, + # 'grad_values:', grad_values) + + res = torch.gather(grad_values, dim, indices) + + # res = torch.empty_like(grad_values) + # print('here') + # res = 
res.view(-1, res.size(-1)) + # indices = indices.view(-1, indices.size(-1)) + # torch.repeat_interleave(torch.arange(indices.size(0)) + # res[indices] = grad_values # + grad_indices + # print('here 2') return res