diff --git a/src/decagon_pytorch/convolve/dense.py b/src/decagon_pytorch/convolve/dense.py index 11e2b0d..37cb700 100644 --- a/src/decagon_pytorch/convolve/dense.py +++ b/src/decagon_pytorch/convolve/dense.py @@ -5,8 +5,8 @@ import torch -from .dropout import dropout -from .weights import init_glorot +from ..dropout import dropout +from ..weights import init_glorot from typing import List, Callable diff --git a/src/decagon_pytorch/convolve/sparse.py b/src/decagon_pytorch/convolve/sparse.py index 7cbabf2..c472007 100644 --- a/src/decagon_pytorch/convolve/sparse.py +++ b/src/decagon_pytorch/convolve/sparse.py @@ -5,8 +5,8 @@ import torch -from .dropout import dropout_sparse -from .weights import init_glorot +from ..dropout import dropout_sparse +from ..weights import init_glorot from typing import List, Callable diff --git a/src/decagon_pytorch/convolve/universal.py b/src/decagon_pytorch/convolve/universal.py index 0bcc8c3..f39d1a8 100644 --- a/src/decagon_pytorch/convolve/universal.py +++ b/src/decagon_pytorch/convolve/universal.py @@ -5,9 +5,9 @@ import torch -from .dropout import dropout_sparse, \ +from ..dropout import dropout_sparse, \ dropout -from .weights import init_glorot +from ..weights import init_glorot from typing import List, Callable diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py index 1f8b7a0..8dee490 100644 --- a/tests/decagon_pytorch/test_convolve.py +++ b/tests/decagon_pytorch/test_convolve.py @@ -179,24 +179,24 @@ def test_graph_conv(): assert np.all(latent_dense.detach().numpy() == latent_sparse.detach().numpy()) -def setup_function(fun): - if fun == test_dropout_graph_conv_activation or \ - fun == test_multi_dgca: - print('Disabling dropout for testing...') - setup_function.old_dropout = decagon_pytorch.convolve.dropout, \ - decagon_pytorch.convolve.dropout_sparse - - decagon_pytorch.convolve.dropout = lambda x, keep_prob: x - decagon_pytorch.convolve.dropout_sparse = lambda x, keep_prob: x - - -def teardown_function(fun): - print('Re-enabling dropout...') - if fun == test_dropout_graph_conv_activation or \ - fun == test_multi_dgca: - decagon_pytorch.convolve.dropout, \ - decagon_pytorch.convolve.dropout_sparse = \ - setup_function.old_dropout +# def setup_function(fun): +# if fun == test_dropout_graph_conv_activation or \ +# fun == test_multi_dgca: +# print('Disabling dropout for testing...') +# setup_function.old_dropout = decagon_pytorch.convolve.dropout, \ +# decagon_pytorch.convolve.dropout_sparse +# +# decagon_pytorch.convolve.dropout = lambda x, keep_prob: x +# decagon_pytorch.convolve.dropout_sparse = lambda x, keep_prob: x +# +# +# def teardown_function(fun): +# print('Re-enabling dropout...') +# if fun == test_dropout_graph_conv_activation or \ +# fun == test_multi_dgca: +# decagon_pytorch.convolve.dropout, \ +# decagon_pytorch.convolve.dropout_sparse = \ +# setup_function.old_dropout def flexible_dropout_graph_conv_activation_torch(keep_prob=1.): @@ -211,7 +211,20 @@ def flexible_dropout_graph_conv_activation_torch(keep_prob=1.): return latent -def test_dropout_graph_conv_activation(): +def _disable_dropout(monkeypatch): + monkeypatch.setattr(decagon_pytorch.convolve.dense, 'dropout', + lambda x, keep_prob: x) + monkeypatch.setattr(decagon_pytorch.convolve.sparse, 'dropout_sparse', + lambda x, keep_prob: x) + monkeypatch.setattr(decagon_pytorch.convolve.universal, 'dropout', + lambda x, keep_prob: x) + monkeypatch.setattr(decagon_pytorch.convolve.universal, 'dropout_sparse', + lambda x, keep_prob: x) + + +def test_dropout_graph_conv_activation(monkeypatch): + _disable_dropout(monkeypatch) + for i in range(11): keep_prob = i/10. if keep_prob == 0: @@ -243,7 +256,9 @@ def test_dropout_graph_conv_activation(): assert np.all(latent_sparse[nonzero] == latent_flex[nonzero]) -def test_multi_dgca(): +def test_multi_dgca(monkeypatch): + _disable_dropout(monkeypatch) + keep_prob = .5 torch.random.manual_seed(0)