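```python
import torch


def dropout_sparse(x, keep_prob):
    """Dropout for sparse COO tensors.

    Each stored (nonzero) entry is kept with probability ``keep_prob`` and the
    survivors are scaled by ``1 / keep_prob`` (inverted dropout).
    """
    x = x.coalesce()
    i = x.indices()   # shape (ndim, nnz); public accessor, valid after coalesce()
    v = x.values()    # shape (nnz, ...)
    size = x.size()
    # floor(keep_prob + U), U ~ Uniform[0, 1), is 1 with probability keep_prob.
    # Sample on v's device so the mask matches CUDA tensors too.
    n = keep_prob + torch.rand(v.shape[0], device=v.device)
    n = torch.floor(n).to(torch.bool)
    i = i[:, n]
    v = v[n]
    x = torch.sparse_coo_tensor(i, v, size=size)
    return x * (1. / keep_prob)
```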
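The mask in `dropout_sparse` relies on the fact that, for U uniform on [0, 1), floor(keep_prob + U) equals 1 exactly when U ≥ 1 − keep_prob, which happens with probability keep_prob. Dividing the survivors by keep_prob then leaves the expected value of every entry unchanged. The following is a minimal, self-contained sketch of the same idea and is not part of the original file: `sparse_dropout_sketch` is a hypothetical name, it swaps the floor trick for the equivalent comparison `torch.rand(nnz) < keep_prob`, and it rescales the values directly instead of multiplying the whole sparse tensor. The diagonal test matrix is an arbitrary example.

```python
import torch


def sparse_dropout_sketch(x: torch.Tensor, keep_prob: float) -> torch.Tensor:
    """Hypothetical variant of dropout_sparse using an explicit Bernoulli mask.

    rand(nnz) < keep_prob is True with probability keep_prob, the same
    distribution as floor(keep_prob + rand(nnz)) in dropout_sparse.
    """
    x = x.coalesce()
    keep = torch.rand(x.values().shape[0], device=x.device) < keep_prob
    # Scale surviving values by 1/keep_prob so E[output] == x (inverted dropout).
    return torch.sparse_coo_tensor(
        x.indices()[:, keep], x.values()[keep] / keep_prob, size=x.size()
    )


torch.manual_seed(0)
dense = torch.zeros(100, 100)
dense[torch.arange(100), torch.arange(100)] = 1.0  # 100 nonzeros on the diagonal
x = dense.to_sparse()

y = sparse_dropout_sketch(x, keep_prob=0.5).coalesce()
print(y.indices().shape[1])  # number of surviving entries, roughly half of 100
print(y.values())            # every survivor equals 1.0 / 0.5 == 2.0

# Averaging many independent draws approaches the original tensor.
total = torch.zeros_like(dense)
for _ in range(2000):
    total += sparse_dropout_sketch(x, 0.5).to_dense()
mean = total / 2000
print((mean - dense).abs().max())  # small; shrinks as the number of draws grows
```

Only the stored entries are sampled, so the cost scales with the number of nonzeros rather than the full tensor size, and implicit zeros always stay zero.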