|  | @@ -179,24 +179,24 @@ def test_graph_conv(): | 
														
													
														
															
																|  |  | assert np.all(latent_dense.detach().numpy() == latent_sparse.detach().numpy()) |  |  | assert np.all(latent_dense.detach().numpy() == latent_sparse.detach().numpy()) | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  | def setup_function(fun): |  |  |  | 
														
													
														
															
																|  |  | if fun == test_dropout_graph_conv_activation or \ |  |  |  | 
														
													
														
															
																|  |  | fun == test_multi_dgca: |  |  |  | 
														
													
														
															
																|  |  | print('Disabling dropout for testing...') |  |  |  | 
														
													
														
															
																|  |  | setup_function.old_dropout = decagon_pytorch.convolve.dropout, \ |  |  |  | 
														
													
														
															
																|  |  | decagon_pytorch.convolve.dropout_sparse |  |  |  | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  | decagon_pytorch.convolve.dropout = lambda x, keep_prob: x |  |  |  | 
														
													
														
															
																|  |  | decagon_pytorch.convolve.dropout_sparse = lambda x, keep_prob: x |  |  |  | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  | def teardown_function(fun): |  |  |  | 
														
													
														
															
																|  |  | print('Re-enabling dropout...') |  |  |  | 
														
													
														
															
																|  |  | if fun == test_dropout_graph_conv_activation or \ |  |  |  | 
														
													
														
															
																|  |  | fun == test_multi_dgca: |  |  |  | 
														
													
														
															
																|  |  | decagon_pytorch.convolve.dropout, \ |  |  |  | 
														
													
														
															
																|  |  | decagon_pytorch.convolve.dropout_sparse = \ |  |  |  | 
														
													
														
															
																|  |  | setup_function.old_dropout |  |  |  | 
														
													
														
															
																|  |  |  |  |  | # def setup_function(fun): | 
														
													
														
															
																|  |  |  |  |  | #     if fun == test_dropout_graph_conv_activation or \ | 
														
													
														
															
																|  |  |  |  |  | #         fun == test_multi_dgca: | 
														
													
														
															
																|  |  |  |  |  | #         print('Disabling dropout for testing...') | 
														
													
														
															
																|  |  |  |  |  | #         setup_function.old_dropout = decagon_pytorch.convolve.dropout, \ | 
														
													
														
															
																|  |  |  |  |  | #             decagon_pytorch.convolve.dropout_sparse | 
														
													
														
															
																|  |  |  |  |  | # | 
														
													
														
															
																|  |  |  |  |  | #         decagon_pytorch.convolve.dropout = lambda x, keep_prob: x | 
														
													
														
															
																|  |  |  |  |  | #         decagon_pytorch.convolve.dropout_sparse = lambda x, keep_prob: x | 
														
													
														
															
																|  |  |  |  |  | # | 
														
													
														
															
																|  |  |  |  |  | # | 
														
													
														
															
																|  |  |  |  |  | # def teardown_function(fun): | 
														
													
														
															
																|  |  |  |  |  | #     print('Re-enabling dropout...') | 
														
													
														
															
																|  |  |  |  |  | #     if fun == test_dropout_graph_conv_activation or \ | 
														
													
														
															
																|  |  |  |  |  | #         fun == test_multi_dgca: | 
														
													
														
															
																|  |  |  |  |  | #         decagon_pytorch.convolve.dropout, \ | 
														
													
														
															
																|  |  |  |  |  | #             decagon_pytorch.convolve.dropout_sparse = \ | 
														
													
														
															
																|  |  |  |  |  | #             setup_function.old_dropout | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  | def flexible_dropout_graph_conv_activation_torch(keep_prob=1.): |  |  | def flexible_dropout_graph_conv_activation_torch(keep_prob=1.): | 
														
													
												
													
														
															
																|  | @@ -211,7 +211,20 @@ def flexible_dropout_graph_conv_activation_torch(keep_prob=1.): | 
														
													
														
															
																|  |  | return latent |  |  | return latent | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  | def test_dropout_graph_conv_activation(): |  |  |  | 
														
													
														
															
																|  |  |  |  |  | def _disable_dropout(monkeypatch): | 
														
													
														
															
																|  |  |  |  |  | monkeypatch.setattr(decagon_pytorch.convolve.dense, 'dropout', | 
														
													
														
															
																|  |  |  |  |  | lambda x, keep_prob: x) | 
														
													
														
															
																|  |  |  |  |  | monkeypatch.setattr(decagon_pytorch.convolve.sparse, 'dropout_sparse', | 
														
													
														
															
																|  |  |  |  |  | lambda x, keep_prob: x) | 
														
													
														
															
																|  |  |  |  |  | monkeypatch.setattr(decagon_pytorch.convolve.universal, 'dropout', | 
														
													
														
															
																|  |  |  |  |  | lambda x, keep_prob: x) | 
														
													
														
															
																|  |  |  |  |  | monkeypatch.setattr(decagon_pytorch.convolve.universal, 'dropout_sparse', | 
														
													
														
															
																|  |  |  |  |  | lambda x, keep_prob: x) | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  |  |  |  | def test_dropout_graph_conv_activation(monkeypatch): | 
														
													
														
															
																|  |  |  |  |  | _disable_dropout(monkeypatch) | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  | for i in range(11): |  |  | for i in range(11): | 
														
													
														
															
																|  |  | keep_prob = i/10. |  |  | keep_prob = i/10. | 
														
													
														
															
																|  |  | if keep_prob == 0: |  |  | if keep_prob == 0: | 
														
													
												
													
														
															
																|  | @@ -243,7 +256,9 @@ def test_dropout_graph_conv_activation(): | 
														
													
														
															
																|  |  | assert np.all(latent_sparse[nonzero] == latent_flex[nonzero]) |  |  | assert np.all(latent_sparse[nonzero] == latent_flex[nonzero]) | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  | def test_multi_dgca(): |  |  |  | 
														
													
														
															
																|  |  |  |  |  | def test_multi_dgca(monkeypatch): | 
														
													
														
															
																|  |  |  |  |  | _disable_dropout(monkeypatch) | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  | keep_prob = .5 |  |  | keep_prob = .5 | 
														
													
														
															
																|  |  |  |  |  |  | 
														
													
														
															
																|  |  | torch.random.manual_seed(0) |  |  | torch.random.manual_seed(0) | 
														
													
												
													
														
															
																|  | 
 |