
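```python
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch


def dropout_sparse(x, keep_prob):
    # Coalesce so every stored (index, value) pair is unique.
    x = x.coalesce()
    i = x._indices()
    v = x._values()
    size = x.size()

    # floor(keep_prob + U), U ~ Uniform[0, 1), is 1 with probability keep_prob,
    # so True marks a stored entry that survives dropout.
    # The random draw is made on the same device as the values.
    n = keep_prob + torch.rand(len(v), device=v.device)
    n = torch.floor(n).to(torch.bool)
    i = i[:, n]
    v = v[n]
    x = torch.sparse_coo_tensor(i, v, size=size)

    # Inverted dropout: rescale survivors so the expected value is unchanged.
    return x * (1. / keep_prob)


def dropout_dense(x, keep_prob):
    # Copy without detaching so gradients still flow through the kept entries.
    x = x.clone()
    # Coordinates of the nonzero entries (assumes a 2-D tensor, see indexing below).
    i = torch.nonzero(x, as_tuple=False)

    # Same trick as above, inverted: True marks an entry to be zeroed.
    n = keep_prob + torch.rand(len(i), device=x.device)
    n = (1. - torch.floor(n)).to(torch.bool)
    x[i[n, 0], i[n, 1]] = 0.

    return x * (1. / keep_prob)


def dropout(x, keep_prob):
    # Dispatch on the tensor layout.
    if x.is_sparse:
        return dropout_sparse(x, keep_prob)
    else:
        return dropout_dense(x, keep_prob)
```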
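Both dropout functions build their keep/drop masks from the fact that floor(keep_prob + U), with U drawn uniformly from [0, 1), reaches 1 exactly when U >= 1 - keep_prob, which happens with probability keep_prob. The sketch below isolates that trick so it can be checked on its own, then applies it to a small sparse COO tensor. The names `keep_mask` and `sparse_dropout_sketch` are hypothetical and not part of this file. The optional `generator` argument is an addition for reproducible masks. The sketch also uses the public `indices()`/`values()` accessors instead of `_indices()`/`_values()`.

```python
from typing import Optional

import torch


def keep_mask(n: int, keep_prob: float,
              generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # Hypothetical helper: each entry is True with probability keep_prob.
    # floor(keep_prob + U), U ~ Uniform[0, 1), is >= 1 exactly when U >= 1 - keep_prob.
    u = torch.rand(n, generator=generator)
    return torch.floor(keep_prob + u).to(torch.bool)


def sparse_dropout_sketch(x: torch.Tensor, keep_prob: float,
                          generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # Hypothetical helper mirroring dropout_sparse: drop stored entries of a
    # sparse COO tensor and divide the survivors by keep_prob.
    x = x.coalesce()
    mask = keep_mask(x._nnz(), keep_prob, generator)
    indices = x.indices()[:, mask]
    values = x.values()[mask] / keep_prob
    return torch.sparse_coo_tensor(indices, values, size=x.size())


# Usage: a reproducible mask and a dropped-out 3x3 sparse matrix.
g = torch.Generator().manual_seed(0)
print(keep_mask(100_000, 0.7, g).float().mean().item())  # fraction kept, near 0.7
x = torch.sparse_coo_tensor(torch.tensor([[0, 1, 2], [1, 0, 2]]),
                            torch.tensor([2.0, 4.0, 6.0]), size=(3, 3))
print(sparse_dropout_sketch(x, 0.5, g).to_dense())
```

Dividing the survivors by keep_prob keeps each entry's expected value equal to its original value. That is the standard "inverted dropout" convention, under which dropout can simply be skipped at evaluation time.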