diff --git a/src/triacontagon/dropout.py b/src/triacontagon/dropout.py index 2fb8728..e9dde92 100644 --- a/src/triacontagon/dropout.py +++ b/src/triacontagon/dropout.py @@ -26,7 +26,7 @@ def dropout_sparse(x, keep_prob): def dropout_dense(x, keep_prob): # print('dropout_dense()') x = x.clone() - i = torch.nonzero(x) + i = torch.nonzero(x, as_tuple=False) n = keep_prob + torch.rand(len(i)) n = (1. - torch.floor(n)).to(torch.bool) diff --git a/tests/triacontagon/test_dropout.py b/tests/triacontagon/test_dropout.py new file mode 100644 index 0000000..abdb04c --- /dev/null +++ b/tests/triacontagon/test_dropout.py @@ -0,0 +1,26 @@ +from triacontagon.dropout import dropout_sparse, \ + dropout_dense +import torch +import numpy as np + + +def test_dropout_01(): + for i in range(11): + torch.random.manual_seed(i) + a = torch.rand((5, 10)) + a[a < .5] = 0 + + keep_prob=i/10. + np.finfo(np.float32).eps + + torch.random.manual_seed(i) + b = dropout_dense(a, keep_prob=keep_prob) + + torch.random.manual_seed(i) + c = dropout_sparse(a.to_sparse(), keep_prob=keep_prob) + + print('keep_prob:', keep_prob) + print('a:', a.detach().cpu().numpy()) + print('b:', b.detach().cpu().numpy()) + print('c:', c, c.to_dense().detach().cpu().numpy()) + + assert torch.all(b == c.to_dense())