IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Fix tests.

master
Stanislaw Adaszewski 3 years ago
parent
commit
241a2f92b9
4 changed files with 41 additions and 26 deletions
  1. +2
    -2
      src/decagon_pytorch/convolve/dense.py
  2. +2
    -2
      src/decagon_pytorch/convolve/sparse.py
  3. +2
    -2
      src/decagon_pytorch/convolve/universal.py
  4. +35
    -20
      tests/decagon_pytorch/test_convolve.py

+ 2
- 2
src/decagon_pytorch/convolve/dense.py View File

@@ -5,8 +5,8 @@
import torch
from .dropout import dropout
from .weights import init_glorot
from ..dropout import dropout
from ..weights import init_glorot
from typing import List, Callable


+ 2
- 2
src/decagon_pytorch/convolve/sparse.py View File

@@ -5,8 +5,8 @@
import torch
from .dropout import dropout_sparse
from .weights import init_glorot
from ..dropout import dropout_sparse
from ..weights import init_glorot
from typing import List, Callable


+ 2
- 2
src/decagon_pytorch/convolve/universal.py View File

@@ -5,9 +5,9 @@
import torch
from .dropout import dropout_sparse, \
from ..dropout import dropout_sparse, \
dropout
from .weights import init_glorot
from ..weights import init_glorot
from typing import List, Callable


+ 35
- 20
tests/decagon_pytorch/test_convolve.py View File

@@ -179,24 +179,24 @@ def test_graph_conv():
assert np.all(latent_dense.detach().numpy() == latent_sparse.detach().numpy())
def setup_function(fun):
if fun == test_dropout_graph_conv_activation or \
fun == test_multi_dgca:
print('Disabling dropout for testing...')
setup_function.old_dropout = decagon_pytorch.convolve.dropout, \
decagon_pytorch.convolve.dropout_sparse
decagon_pytorch.convolve.dropout = lambda x, keep_prob: x
decagon_pytorch.convolve.dropout_sparse = lambda x, keep_prob: x
def teardown_function(fun):
print('Re-enabling dropout...')
if fun == test_dropout_graph_conv_activation or \
fun == test_multi_dgca:
decagon_pytorch.convolve.dropout, \
decagon_pytorch.convolve.dropout_sparse = \
setup_function.old_dropout
# def setup_function(fun):
# if fun == test_dropout_graph_conv_activation or \
# fun == test_multi_dgca:
# print('Disabling dropout for testing...')
# setup_function.old_dropout = decagon_pytorch.convolve.dropout, \
# decagon_pytorch.convolve.dropout_sparse
#
# decagon_pytorch.convolve.dropout = lambda x, keep_prob: x
# decagon_pytorch.convolve.dropout_sparse = lambda x, keep_prob: x
#
#
# def teardown_function(fun):
# print('Re-enabling dropout...')
# if fun == test_dropout_graph_conv_activation or \
# fun == test_multi_dgca:
# decagon_pytorch.convolve.dropout, \
# decagon_pytorch.convolve.dropout_sparse = \
# setup_function.old_dropout
def flexible_dropout_graph_conv_activation_torch(keep_prob=1.):
@@ -211,7 +211,20 @@ def flexible_dropout_graph_conv_activation_torch(keep_prob=1.):
return latent
def test_dropout_graph_conv_activation():
def _disable_dropout(monkeypatch):
monkeypatch.setattr(decagon_pytorch.convolve.dense, 'dropout',
lambda x, keep_prob: x)
monkeypatch.setattr(decagon_pytorch.convolve.sparse, 'dropout_sparse',
lambda x, keep_prob: x)
monkeypatch.setattr(decagon_pytorch.convolve.universal, 'dropout',
lambda x, keep_prob: x)
monkeypatch.setattr(decagon_pytorch.convolve.universal, 'dropout_sparse',
lambda x, keep_prob: x)
def test_dropout_graph_conv_activation(monkeypatch):
_disable_dropout(monkeypatch)
for i in range(11):
keep_prob = i/10.
if keep_prob == 0:
@@ -243,7 +256,9 @@ def test_dropout_graph_conv_activation():
assert np.all(latent_sparse[nonzero] == latent_flex[nonzero])
def test_multi_dgca():
def test_multi_dgca(monkeypatch):
_disable_dropout(monkeypatch)
keep_prob = .5
torch.random.manual_seed(0)


Loading…
Cancel
Save