IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add test_dropout.

master
Stanislaw Adaszewski 3 years ago
parent
commit
9defd6c9e4
2 changed files with 27 additions and 1 deletions
  1. +1
    -1
      src/triacontagon/dropout.py
  2. +26
    -0
      tests/triacontagon/test_dropout.py

+ 1
- 1
src/triacontagon/dropout.py View File

@@ -26,7 +26,7 @@ def dropout_sparse(x, keep_prob):
def dropout_dense(x, keep_prob):
# print('dropout_dense()')
x = x.clone()
i = torch.nonzero(x)
i = torch.nonzero(x, as_tuple=False)
n = keep_prob + torch.rand(len(i))
n = (1. - torch.floor(n)).to(torch.bool)


+ 26
- 0
tests/triacontagon/test_dropout.py View File

@@ -0,0 +1,26 @@
from triacontagon.dropout import dropout_sparse, \
dropout_dense
import torch
import numpy as np
def test_dropout_01():
for i in range(11):
torch.random.manual_seed(i)
a = torch.rand((5, 10))
a[a < .5] = 0
keep_prob=i/10. + np.finfo(np.float32).eps
torch.random.manual_seed(i)
b = dropout_dense(a, keep_prob=keep_prob)
torch.random.manual_seed(i)
c = dropout_sparse(a.to_sparse(), keep_prob=keep_prob)
print('keep_prob:', keep_prob)
print('a:', a.detach().cpu().numpy())
print('b:', b.detach().cpu().numpy())
print('c:', c, c.to_dense().detach().cpu().numpy())
assert torch.all(b == c.to_dense())

Loading…
Cancel
Save