IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

dropout.py 862B

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from .normalize import _sparse_coo_tensor
  7. def dropout_sparse(x, keep_prob):
  8. x = x.coalesce()
  9. i = x._indices()
  10. v = x._values()
  11. size = x.size()
  12. n = keep_prob + torch.rand(len(v))
  13. n = torch.floor(n).to(torch.bool)
  14. i = i[:,n]
  15. v = v[n]
  16. x = _sparse_coo_tensor(i, v, size=size)
  17. return x * (1./keep_prob)
  18. def dropout_dense(x, keep_prob):
  19. # print('dropout_dense()')
  20. x = x.clone()
  21. i = torch.nonzero(x)
  22. n = keep_prob + torch.rand(len(i))
  23. n = (1. - torch.floor(n)).to(torch.bool)
  24. x[i[n, 0], i[n, 1]] = 0.
  25. return x * (1./keep_prob)
  26. def dropout(x, keep_prob):
  27. if x.is_sparse:
  28. return dropout_sparse(x, keep_prob)
  29. else:
  30. return dropout_dense(x, keep_prob)