|
- import torch
- import numpy as np
-
-
def dfill(a):
    """Return, for every position, the start index of its run of equal values.

    ``a`` is expected to be 1-D. Consecutive equal elements form a run; each
    output element is the index of the first element of the run containing
    that position, e.g. ``[1, 1, 2, 2, 2, 3] -> [0, 0, 2, 2, 2, 5]``.

    Args:
        a: 1-D tensor of comparable values.

    Returns:
        int64 tensor with ``numel(a)`` elements on ``a.device``. An empty
        input yields an empty tensor.
    """
    n = torch.numel(a)
    device = a.device
    if n == 0:
        # Without this guard, indexing run starts into an empty range fails.
        return torch.empty(0, dtype=torch.int64, device=device)

    # Indices where a new run begins (a value differs from its predecessor).
    changes = torch.nonzero(a[1:] != a[:-1], as_tuple=True)[0] + 1
    # Run boundaries: 0, every change point, and n as the final sentinel.
    bounds = torch.cat([
        torch.zeros(1, dtype=torch.int64, device=device),
        changes,
        torch.full((1,), n, dtype=torch.int64, device=device),
    ])
    starts = bounds[:-1]
    # Repeat each run's start index once per element of that run.
    return torch.repeat_interleave(starts, bounds[1:] - starts)
-
-
def argunsort(s):
    """Return the inverse of the permutation ``s``.

    If ``s`` is a permutation of ``0..n-1`` (for example, the output of an
    argsort), the result ``u`` satisfies ``u[s[k]] == k`` for every ``k``. In
    other words, ``u`` maps a position in the original order to its position
    in the sorted order.

    Args:
        s: 1-D integer tensor holding a permutation of ``0..numel(s)-1``.

    Returns:
        int64 tensor on ``s.device`` with the same number of elements.
    """
    n = torch.numel(s)
    # Allocate on s.device: scattering with CUDA indices into a CPU tensor fails.
    u = torch.empty(n, dtype=torch.int64, device=s.device)
    u[s] = torch.arange(n, device=s.device)
    return u
-
-
def cumcount(a):
    """Number each element by how many equal values precede it in ``a``.

    For a 1-D tensor ``a``, element ``k`` of the result is the number of
    indices ``j < k`` with ``a[j] == a[k]``. For example,
    ``[5, 3, 5, 5, 3] -> [0, 0, 1, 2, 1]``.

    Args:
        a: 1-D tensor. It is copied to host memory for the argsort.

    Returns:
        int64 tensor with ``numel(a)`` elements on ``a.device``.
    """
    n = torch.numel(a)
    # The sort must be *stable* so that equal values keep their original
    # relative order; otherwise the counts within a group are handed out in
    # arbitrary order. NumPy's default kind ("quicksort") is not stable.
    order = np.argsort(a.detach().cpu().numpy(), kind="stable")
    s = torch.from_numpy(order).to(device=a.device, dtype=torch.int64)
    # i[k] is the position of original element k in sorted order.
    i = argunsort(s)
    b = a[s]
    # A sorted position minus the start of its run of equal values gives the
    # number of earlier equal elements; then map back to the original order.
    ranks = torch.arange(n, device=a.device)
    return (ranks - dfill(b))[i]
|