IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

31 lines
677B

  1. import torch
  2. import numpy as np
  def dfill(a):
      """Map each position of a 1-D tensor to the start index of its run.

      A "run" is a maximal stretch of consecutive equal elements, e.g.
      ``dfill(torch.tensor([1, 1, 2, 2, 2, 5]))`` is
      ``tensor([0, 0, 2, 2, 2, 5])``. When ``a`` is sorted, each run is one
      group of equal values, so the result is the position of the first
      element of each element's group.

      Args:
          a: tensor assumed to be 1-D (runs are detected by comparing
              ``a[:-1]`` with ``a[1:]`` along the first dimension only).

      Returns:
          int64 tensor with ``numel(a)`` entries, on ``a``'s device.
      """
      n = torch.numel(a)
      if n == 0:
          # The boundary construction below would index into an empty range.
          return torch.empty(0, dtype=torch.int64, device=a.device)
      # Run boundaries: 0, every index where the value changes, and n.
      changes = torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1
      bounds = torch.cat([
          torch.tensor([0], device=a.device),
          changes,
          torch.tensor([n], device=a.device),
      ])
      starts = bounds[:-1]
      lengths = bounds[1:] - bounds[:-1]
      # Repeat each run's start index once for every element in that run.
      return torch.repeat_interleave(starts, lengths)
  def argunsort(s):
      """Return the inverse of the permutation ``s``.

      The result ``u`` satisfies ``u[s[k]] == k`` for every ``k``, so indexing
      data that was reordered by ``s`` with ``u`` restores the original order.

      Args:
          s: 1-D integer tensor, assumed to be a permutation of ``0..n-1``
              (e.g. the output of an argsort); this is not validated.

      Returns:
          int64 tensor ``u`` with ``numel(s)`` entries, on ``s``'s device.
      """
      n = torch.numel(s)
      # Allocate on s's device: assigning through device-resident indices
      # into a CPU tensor (or vice versa) raises.
      u = torch.empty(n, dtype=torch.int64, device=s.device)
      u[s] = torch.arange(n, device=s.device)
      return u
  def cumcount(a):
      """Number each element by how many equal elements precede it.

      Analogous to pandas ``GroupBy.cumcount`` with the values of ``a`` as the
      group keys: the first occurrence of a value gets 0, the second 1, and so
      on. For example ``cumcount(torch.tensor([5, 3, 5, 5, 3]))`` is
      ``tensor([0, 0, 1, 2, 1])``.

      Args:
          a: tensor of sortable values, assumed to be 1-D.

      Returns:
          int64 tensor with ``numel(a)`` entries, on ``a``'s device.
      """
      n = torch.numel(a)
      if n == 0:
          return torch.empty(0, dtype=torch.int64, device=a.device)
      # A *stable* sort keeps equal values in their original order, which is
      # what makes the per-group counter follow order of occurrence; NumPy's
      # default quicksort does not guarantee this.
      s = np.argsort(a.detach().cpu().numpy(), kind="stable")
      s = torch.tensor(s, device=a.device)
      i = argunsort(s)
      b = a[s]
      # Position in sorted order minus the start of its group = rank within
      # the group; indexing with i maps ranks back to the original positions.
      return (torch.arange(n, device=a.device) - dfill(b))[i]