IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

32 lines
716B

  1. import torch
  2. import numpy as np
  3. def dfill(a):
  4. n = torch.numel(a)
  5. b = torch.cat([
  6. torch.tensor([0]),
  7. torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1,
  8. torch.tensor([n])
  9. ])
  10. # print('b:',b)
  11. res = torch.arange(n)[b[:-1]]
  12. res = torch.repeat_interleave(res, b[1:] - b[:-1])
  13. return res
  14. def argunsort(s):
  15. n = torch.numel(s)
  16. u = torch.empty(n, dtype=torch.int64)
  17. u[s] = torch.arange(n)
  18. return u
  19. def cumcount(a):
  20. n = torch.numel(a)
  21. s = np.argsort(a.detach().cpu().numpy(), kind='mergesort')
  22. s = torch.tensor(s, device=a.device)
  23. i = argunsort(s)
  24. b = a[s]
  25. return (torch.arange(n) - dfill(b))[i]