IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
您最多可以选择25个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过35个字符。

12345678910111213141516171819202122232425262728293031323334353637
  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. def dropout_sparse(x, keep_prob):
  7. """Dropout for sparse tensors.
  8. """
  9. x = x.coalesce()
  10. i = x._indices()
  11. v = x._values()
  12. size = x.size()
  13. n = keep_prob + torch.rand(len(v))
  14. n = torch.floor(n).to(torch.bool)
  15. i = i[:,n]
  16. v = v[n]
  17. x = torch.sparse_coo_tensor(i, v, size=size)
  18. return x * (1./keep_prob)
  19. def dropout(x, keep_prob):
  20. """Dropout for dense tensors.
  21. """
  22. shape = x.shape
  23. x = torch.flatten(x)
  24. n = keep_prob + torch.rand(len(x))
  25. n = (1. - torch.floor(n)).to(torch.bool)
  26. x[n] = 0
  27. x = torch.reshape(x, shape)
  28. # x = torch.nn.functional.dropout(x, p=1.-keep_prob)
  29. return x * (1./keep_prob)