| @@ -124,4 +124,23 @@ def test_sparse_dropout_grap_conv_activation(): | |||||
| def test_sparse_multi_dgca(): | def test_sparse_multi_dgca(): | ||||
| pass | |||||
| latent_torch = None | |||||
| latent_tf = [] | |||||
| for i in range(11): | |||||
| keep_prob = i/10. + np.finfo(np.float32).eps | |||||
| latent_torch = sparse_dropout_graph_conv_activation_torch(keep_prob) \ | |||||
| if latent_torch is None \ | |||||
| else latent_torch + sparse_dropout_graph_conv_activation_torch(keep_prob) | |||||
| latent_tf.append(sparse_dropout_graph_conv_activation_tf(keep_prob)) | |||||
| latent_torch = torch.nn.functional.normalize(latent_torch, p=2, dim=1) | |||||
| latent_tf = tf.add_n(latent_tf) | |||||
| latent_tf = tf.nn.l2_normalize(latent_tf, dim=1) | |||||
| latent_torch = latent_torch.detach().numpy() | |||||
| latent_tf = latent_tf.eval(session = tf.Session()) | |||||
| assert np.all(latent_torch - latent_tf < .000001) | |||||