IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
No puede seleccionar más de 25 temas Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

test_dropout.py 889B

12345678910111213141516171819202122232425262728293031323334
  1. from decagon_pytorch.dropout import dropout_sparse
  2. import torch
  3. import numpy as np
  4. def dropout_dense(a, keep_prob):
  5. i = np.array(np.where(a))
  6. v = a[i[0, :], i[1, :]]
  7. # torch.random.manual_seed(0)
  8. n = keep_prob + torch.rand(len(v))
  9. n = torch.floor(n).to(torch.bool)
  10. i = i[:, n]
  11. v = v[n]
  12. x = torch.sparse_coo_tensor(i, v, size=a.shape)
  13. return x * (1./keep_prob)
  14. def test_dropout_sparse():
  15. for i in range(11):
  16. torch.random.manual_seed(i)
  17. a = torch.rand((5, 10))
  18. a[a < .5] = 0
  19. keep_prob=i/10. + np.finfo(np.float32).eps
  20. torch.random.manual_seed(i)
  21. b = dropout_dense(a, keep_prob=keep_prob)
  22. torch.random.manual_seed(i)
  23. c = dropout_sparse(a.to_sparse(), keep_prob=keep_prob)
  24. assert np.all(np.array(b.to_dense()) == np.array(c.to_dense()))