IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

Add torch_stablesort.md.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
283d9387a8
共有 2 個檔案被更改,包括 44 行新增5 行删除
  1. +29
    -0
      src/torch_stablesort/torch_stablesort.md
  2. +15
    -5
      src/torch_stablesort/torch_stablesort.py

+ 29
- 0
src/torch_stablesort/torch_stablesort.md 查看文件

@@ -0,0 +1,29 @@
# torch_stablesort
## Introduction
### Stable sorting algorithms
Stable sort algorithms sort repeated elements in the same order that they appear in the input. When sorting some kinds of data, only part of the data is examined when determining the sort order. For example, in the card sorting example to the right, the cards are being sorted by their rank, and their suit is being ignored. This allows the possibility of multiple different correctly sorted versions of the original list. Stable sorting algorithms choose one of these, according to the following rule: if two items compare as equal, like the two 5 cards, then their relative order will be preserved, so that if one came before the other in the input, it will also come before the other in the output.
### PyTorch
PyTorch is an open source machine learning library based on the Torch library, used for applications such as computer vision and natural language processing, primarily developed by Facebook's AI Research lab. It is free and open-source software released under the Modified BSD license.
### PyTorch Extensions
PyTorch provides a plethora of operations related to neural networks, arbitrary tensor algebra, data wrangling and other purposes. However, you may still find yourself in need of a more customized operation. For example, you might want to use a novel activation function you found in a paper, or implement an operation you developed as part of your research.
The easiest way of integrating such a custom operation in PyTorch is to write it in Python by extending Function and Module as outlined here. This gives you the full power of automatic differentiation (spares you from writing derivative functions) as well as the usual expressiveness of Python. However, there may be times when your operation is better implemented in C++. For example, your code may need to be really fast because it is called very frequently in your model or is very expensive even for few calls. Another plausible reason is that it depends on or interacts with other C or C++ libraries. To address such cases, PyTorch provides a very easy way of writing custom C++ extensions.
## Implementation
### setup.py
### dispatch.h
### torch_stablesort.cpp
### torch_stablesort.py

+ 15
- 5
src/torch_stablesort/torch_stablesort.py 查看文件

@@ -7,12 +7,22 @@ class StableSort(torch.autograd.Function):
def forward(ctx, input, dim=-1, descending=False, out=None):
values, indices = \
torch_stablesort_cpp.stable_sort(input, dim, descending, out)
ctx.save_for_backward(input, indices)
return values, indices
ctx.save_for_backward(input, indices, torch.tensor(dim))
return values, indices.detach()
@staticmethod
def backward(ctx, grad_values, grad_indices):
input, indices = ctx.saved_variables
res = torch.empty_like(grad_values)
res[indices] = grad_values + grad_indices
input, indices, dim = ctx.saved_variables
# print('backward(), grad_indices:', grad_indices, 'indices:', indices,
# 'grad_values:', grad_values)
res = torch.gather(grad_values, dim, indices)
# res = torch.empty_like(grad_values)
# print('here')
# res = res.view(-1, res.size(-1))
# indices = indices.view(-1, indices.size(-1))
# torch.repeat_interleave(torch.arange(indices.size(0))
# res[indices] = grad_values # + grad_indices
# print('here 2')
return res

Loading…
取消
儲存