From 8f41021d69df22f2d4fe902db9fc5f8661102a00 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Tue, 12 May 2020 19:47:07 +0200
Subject: [PATCH] Add test_dropout_graph_conv_activation().

---
 src/decagon_pytorch/convolve.py        |  5 ++-
 src/decagon_pytorch/dropout.py         | 13 +++++++
 tests/decagon_pytorch/test_convolve.py | 52 +++++++++++++++++++++++++-
 3 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py
index 44252f3..c781e6b 100644
--- a/src/decagon_pytorch/convolve.py
+++ b/src/decagon_pytorch/convolve.py
@@ -1,5 +1,6 @@
 import torch
-from .dropout import dropout_sparse
+from .dropout import dropout_sparse, \
+    dropout
 from .weights import init_glorot
 
 
@@ -79,7 +80,7 @@ class DropoutGraphConvActivation(torch.nn.Module):
         self.activation = activation
 
     def forward(self, x):
-        x = torch.nn.functional.dropout(x, 1.-self.keep_prob)
+        x = dropout(x, keep_prob=self.keep_prob)
         x = self.graph_conv(x)
         x = self.activation(x)
         return x
diff --git a/src/decagon_pytorch/dropout.py b/src/decagon_pytorch/dropout.py
index 3162572..27196fe 100644
--- a/src/decagon_pytorch/dropout.py
+++ b/src/decagon_pytorch/dropout.py
@@ -16,3 +16,16 @@ def dropout_sparse(x, keep_prob):
     x = torch.sparse_coo_tensor(i, v, size=size)
 
     return x * (1./keep_prob)
+
+
+def dropout(x, keep_prob):
+    """Dropout for dense tensors.
+    """
+    shape = x.shape
+    x = torch.flatten(x)
+    n = keep_prob + torch.rand(len(x))
+    n = (1. - torch.floor(n)).to(torch.bool)
+    x[n] = 0
+    x = torch.reshape(x, shape)
+    # x = torch.nn.functional.dropout(x, p=1.-keep_prob)
+    return x * (1./keep_prob)
diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py
index 3adb9fd..74998d0 100644
--- a/tests/decagon_pytorch/test_convolve.py
+++ b/tests/decagon_pytorch/test_convolve.py
@@ -16,6 +16,8 @@ def prepare_data():
         adj_mat[adj_mat < .5] = 0
         adj_mat = np.ceil(adj_mat)
         adjacency_matrices.append(adj_mat)
+    print('latent:', latent)
+    print('adjacency_matrices[0]:', adjacency_matrices[0])
     return latent, adjacency_matrices
 
 
@@ -51,6 +53,18 @@ def graph_conv_torch():
     return latent
 
 
+def dropout_graph_conv_activation_torch(keep_prob=1.):
+    torch.random.manual_seed(0)
+    latent, adjacency_matrices = prepare_data()
+    latent = torch.tensor(latent)
+    adj_mat = adjacency_matrices[0]
+    adj_mat = torch.tensor(adj_mat)
+    conv = decagon_pytorch.convolve.DropoutGraphConvActivation(10, 10,
+        adj_mat, keep_prob=keep_prob)
+    latent = conv(latent)
+    return latent
+
+
 def sparse_graph_conv_torch():
     torch.random.manual_seed(0)
     latent, adjacency_matrices = prepare_data()
@@ -120,7 +134,7 @@ def test_sparse_graph_conv():
     assert np.all(latent_torch.detach().numpy() ==
         latent_tf.eval(session = tf.Session()))
 
 
-def test_sparse_dropout_grap_conv_activation():
+def test_sparse_dropout_graph_conv_activation():
     for i in range(11):
         keep_prob = i/10. + np.finfo(np.float32).eps
@@ -163,3 +177,39 @@ def test_graph_conv():
     latent_sparse = sparse_graph_conv_torch()
     assert np.all(latent_dense.detach().numpy() ==
         latent_sparse.detach().numpy())
+
+
+def setup_function(fun):
+    if fun == test_dropout_graph_conv_activation:
+        setup_function.old_dropout = decagon_pytorch.convolve.dropout, \
+            decagon_pytorch.convolve.dropout_sparse
+
+        decagon_pytorch.convolve.dropout = lambda x, keep_prob: x
+        decagon_pytorch.convolve.dropout_sparse = lambda x, keep_prob: x
+
+
+def teardown_function(fun):
+    if fun == test_dropout_graph_conv_activation:
+        decagon_pytorch.convolve.dropout, \
+            decagon_pytorch.convolve.dropout_sparse = \
+            setup_function.old_dropout
+
+
+def test_dropout_graph_conv_activation():
+    for i in range(11):
+        keep_prob = i/10.
+        if keep_prob == 0:
+            keep_prob += np.finfo(np.float32).eps
+        print('keep_prob:', keep_prob)
+
+        latent_dense = dropout_graph_conv_activation_torch(keep_prob)
+        latent_dense = latent_dense.detach().numpy()
+        print('latent_dense:', latent_dense)
+
+        latent_sparse = sparse_dropout_graph_conv_activation_torch(keep_prob)
+        latent_sparse = latent_sparse.detach().numpy()
+        print('latent_sparse:', latent_sparse)
+
+        nonzero = (latent_dense != 0) & (latent_sparse != 0)
+
+        assert np.all(latent_dense[nonzero] == latent_sparse[nonzero])
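
Note on the new dropout(): it implements "inverted dropout" by hand. torch.floor(keep_prob + torch.rand(n)) equals 1 with probability exactly keep_prob, so each element survives with probability keep_prob, and the final 1/keep_prob rescaling leaves the expected value of every element unchanged. Below is a minimal out-of-place sketch of the same semantics; the helper name dropout_dense_sketch is illustrative only and is not part of the patched module.

    import torch

    def dropout_dense_sketch(x, keep_prob):
        # Bernoulli(keep_prob) mask: rand < keep_prob holds with probability keep_prob.
        mask = (torch.rand(x.shape) < keep_prob).to(x.dtype)
        # Zero the dropped entries and rescale survivors so the expectation is preserved.
        return x * mask * (1. / keep_prob)

    x = torch.ones(4, 5)
    print(dropout_dense_sketch(x, keep_prob=0.8))
    # Kept entries equal 1/0.8 = 1.25; dropped entries are 0.

Two details worth noting: torch.flatten() can return a view of a contiguous tensor, so the committed dropout() may zero entries of its argument in place, whereas the sketch leaves its input untouched; and setup_function()/teardown_function() temporarily swap both dropout functions for identity lambdas, so test_dropout_graph_conv_activation() compares the dense and sparse convolution paths with dropout disabled.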