|
123456789101112131415161718 |
- import torch
-
-
def dropout_sparse(x, keep_prob):
    """Apply inverted dropout to the stored entries of a sparse COO tensor.

    Each explicitly stored entry of ``x`` is kept independently with
    probability ``keep_prob`` and removed otherwise. Surviving values are
    scaled by ``1 / keep_prob`` so the expected value of every entry is
    unchanged. Implicit (unstored) zeros are never touched.

    Args:
        x: A sparse COO tensor. It is coalesced first, so duplicate indices
            are summed before the keep/drop decision is made.
        keep_prob: Probability of keeping each stored entry, in ``[0, 1]``.
            ``0`` drops every entry; ``1`` keeps every entry unscaled.

    Returns:
        A new sparse COO tensor with the same size as ``x`` holding only the
        surviving, rescaled entries.

    Raises:
        ValueError: If ``keep_prob`` is outside ``[0, 1]`` (or is NaN).
    """
    if not 0.0 <= keep_prob <= 1.0:
        raise ValueError(f"keep_prob must be in [0, 1], got {keep_prob}")

    x = x.coalesce()
    # Public accessors; valid because x is coalesced.
    indices = x.indices()
    values = x.values()

    # Draw the keep mask on the values' device so CUDA inputs are not masked
    # with a CPU tensor. rand in [0, 1) < keep_prob holds with prob keep_prob.
    keep = torch.rand(values.shape[0], device=values.device) < keep_prob

    # When keep_prob == 0 every entry is dropped, so the scale is irrelevant;
    # this avoids the 1/0 division the old code crashed on.
    scale = 1.0 / keep_prob if keep_prob > 0 else 0.0

    return torch.sparse_coo_tensor(
        indices[:, keep], values[keep] * scale, size=x.size()
    )
|