import numpy as np
import torch


def dfill(a: torch.Tensor) -> torch.Tensor:
    """Return, for every position of *a*, the start index of its run.

    A "run" is a maximal stretch of consecutive equal values. For example
    ``[5, 5, 7, 7, 7, 9]`` yields ``[0, 0, 2, 2, 2, 5]``.

    Args:
        a: Tensor indexed along its first dimension; the slicing below
            assumes it is 1-D (NOTE(review): confirm callers never pass
            higher-rank input).

    Returns:
        An int64 tensor with ``torch.numel(a)`` entries on ``a.device``.
        Empty input yields an empty tensor.
    """
    n = torch.numel(a)
    device = a.device
    # Positions where the value differs from its predecessor, i.e. where a
    # new run begins (position 0 is added explicitly below).
    change_points = torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1
    # Run boundaries: [0, change points..., n]. The sentinels are created on
    # the input's device so the concatenation also works off-CPU.
    bounds = torch.cat([
        torch.tensor([0], dtype=torch.int64, device=device),
        change_points,
        torch.tensor([n], dtype=torch.int64, device=device),
    ])
    starts = bounds[:-1]
    run_lengths = bounds[1:] - bounds[:-1]
    # Broadcast each run's start index over the length of that run. For
    # empty input this is a single zero-length run, giving an empty result.
    return torch.repeat_interleave(starts, run_lengths)


def argunsort(s: torch.Tensor) -> torch.Tensor:
    """Return the inverse of the permutation *s*.

    If ``s`` are the indices that sort some tensor ``x`` (``x[s]`` is
    sorted), then ``x[s][argunsort(s)]`` restores the original order.

    Args:
        s: Integer index tensor holding a permutation of
            ``0 .. numel(s) - 1``.

    Returns:
        An int64 tensor ``u`` on ``s.device`` with ``u[s[k]] == k``.
    """
    n = torch.numel(s)
    # Allocate on the same device as the permutation so the scatter below
    # does not mix devices.
    u = torch.empty(n, dtype=torch.int64, device=s.device)
    u[s] = torch.arange(n, device=s.device)
    return u


def cumcount(a: torch.Tensor) -> torch.Tensor:
    """Count, for each element, the earlier occurrences of the same value.

    For example ``[1, 2, 1, 1, 2, 3]`` yields ``[0, 0, 1, 2, 1, 0]``.

    Args:
        a: Tensor of values; assumed 1-D (NOTE(review): confirm).

    Returns:
        An int64 tensor with ``torch.numel(a)`` entries on ``a.device``.
        Empty input yields an empty tensor.
    """
    n = torch.numel(a)
    # A stable sort keeps equal values in their original relative order,
    # which is what makes "position within the run" equal to the number of
    # earlier occurrences.
    order = np.argsort(a.detach().cpu().numpy(), kind='mergesort')
    order = torch.tensor(order, device=a.device)
    inverse = argunsort(order)
    sorted_values = a[order]
    # Offset of each sorted element from the start of its run of equal
    # values, then mapped back to the original element order.
    offsets = torch.arange(n, device=a.device) - dfill(sorted_values)
    return offsets[inverse]