import torch


def dropout_sparse(x, keep_prob):
    """Apply (inverted) dropout to the stored entries of a sparse COO tensor.

    The input is coalesced first, so duplicate indices are summed before the
    mask is drawn. Each stored entry is then kept independently with
    probability ``keep_prob`` and dropped otherwise. Surviving values are
    scaled by ``1 / keep_prob`` so the expected value of every entry is
    unchanged.

    Args:
        x: A sparse COO tensor. Hybrid tensors (values with trailing dense
            dimensions) are supported: a whole value row is kept or dropped.
        keep_prob: Probability of keeping each stored entry. Expected to lie
            in (0, 1]. A Python-float value of 0 raises ``ZeroDivisionError``
            from the ``1 / keep_prob`` rescaling.

    Returns:
        A new sparse COO tensor with the same size as ``x``. It holds only the
        kept entries, each multiplied by ``1 / keep_prob``.
    """
    x = x.coalesce()
    # The tensor is coalesced, so the public accessors are valid here.
    indices = x.indices()
    values = x.values()

    # Draw one uniform sample per stored entry, on the same device as the
    # values. This avoids a CPU mask for GPU inputs and uses that device's RNG.
    noise = torch.rand(values.size(0), device=values.device)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob (for
    # keep_prob in [0, 1]) and 0 otherwise; cast it to a boolean keep-mask.
    keep = torch.floor(keep_prob + noise).to(torch.bool)

    kept = torch.sparse_coo_tensor(indices[:, keep], values[keep], size=x.size())
    # Inverted-dropout rescaling keeps E[output] == input.
    return kept * (1.0 / keep_prob)