IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

34 lines
648B

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. def dropout_sparse(x, keep_prob):
  7. x = x.coalesce()
  8. i = x._indices()
  9. v = x._values()
  10. size = x.size()
  11. n = keep_prob + torch.rand(len(v))
  12. n = torch.floor(n).to(torch.bool)
  13. i = i[:,n]
  14. v = v[n]
  15. x = torch.sparse_coo_tensor(i, v, size=size)
  16. return x * (1./keep_prob)
  17. def dropout_dense(x, keep_prob):
  18. x = x.clone().detach()
  19. i = torch.nonzero(x)
  20. n = keep_prob + torch.rand(len(i))
  21. n = (1. - torch.floor(n)).to(torch.bool)
  22. x[i[n, 0], i[n, 1]] = 0.
  23. return x * (1./keep_prob)