If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

32 行
723B

  1. import torch
  2. def dropout_sparse(x, keep_prob):
  3. """Dropout for sparse tensors.
  4. """
  5. x = x.coalesce()
  6. i = x._indices()
  7. v = x._values()
  8. size = x.size()
  9. n = keep_prob + torch.rand(len(v))
  10. n = torch.floor(n).to(torch.bool)
  11. i = i[:,n]
  12. v = v[n]
  13. x = torch.sparse_coo_tensor(i, v, size=size)
  14. return x * (1./keep_prob)
  15. def dropout(x, keep_prob):
  16. """Dropout for dense tensors.
  17. """
  18. shape = x.shape
  19. x = torch.flatten(x)
  20. n = keep_prob + torch.rand(len(x))
  21. n = (1. - torch.floor(n)).to(torch.bool)
  22. x[n] = 0
  23. x = torch.reshape(x, shape)
  24. # x = torch.nn.functional.dropout(x, p=1.-keep_prob)
  25. return x * (1./keep_prob)