diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py index 2bf603a..1624a80 100644 --- a/tests/decagon_pytorch/test_convolve.py +++ b/tests/decagon_pytorch/test_convolve.py @@ -124,4 +124,23 @@ def test_sparse_dropout_grap_conv_activation(): def test_sparse_multi_dgca(): - pass + latent_torch = None + latent_tf = [] + + for i in range(11): + keep_prob = i/10. + np.finfo(np.float32).eps + + latent_torch = sparse_dropout_graph_conv_activation_torch(keep_prob) \ + if latent_torch is None \ + else latent_torch + sparse_dropout_graph_conv_activation_torch(keep_prob) + + latent_tf.append(sparse_dropout_graph_conv_activation_tf(keep_prob)) + + latent_torch = torch.nn.functional.normalize(latent_torch, p=2, dim=1) + latent_tf = tf.add_n(latent_tf) + latent_tf = tf.nn.l2_normalize(latent_tf, dim=1) + + latent_torch = latent_torch.detach().numpy() + latent_tf = latent_tf.eval(session = tf.Session()) + + assert np.all(latent_torch - latent_tf < .000001)