IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

42 lines
839B

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from .normalize import _sparse_coo_tensor
  7. def dropout_sparse(x, keep_prob):
  8. x = x.coalesce()
  9. i = x._indices()
  10. v = x._values()
  11. size = x.size()
  12. n = keep_prob + torch.rand(len(v))
  13. n = torch.floor(n).to(torch.bool)
  14. i = i[:,n]
  15. v = v[n]
  16. x = _sparse_coo_tensor(i, v, size=size)
  17. return x * (1./keep_prob)
  18. def dropout_dense(x, keep_prob):
  19. x = x.clone().detach()
  20. i = torch.nonzero(x)
  21. n = keep_prob + torch.rand(len(i))
  22. n = (1. - torch.floor(n)).to(torch.bool)
  23. x[i[n, 0], i[n, 1]] = 0.
  24. return x * (1./keep_prob)
  25. def dropout(x, keep_prob):
  26. if x.is_sparse:
  27. return dropout_sparse(x, keep_prob)
  28. else:
  29. return dropout_dense(x, keep_prob)