@@ -124,4 +124,23 @@ def test_sparse_dropout_graph_conv_activation():


def test_sparse_multi_dgca():
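    # Compare the Torch and TF dropout/graph-conv activations summed over a
    # sweep of keep_prob values (relies on the helper pair defined above
    # producing comparable outputs).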
    latent_torch = None
    latent_tf = []
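
    # Evaluate both implementations at each keep_prob and accumulate.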
    for i in range(11):
        # float32 eps keeps keep_prob strictly positive when i == 0
        keep_prob = i/10. + np.finfo(np.float32).eps

        latent_torch = sparse_dropout_graph_conv_activation_torch(keep_prob) \
            if latent_torch is None \
            else latent_torch + sparse_dropout_graph_conv_activation_torch(keep_prob)

        latent_tf.append(sparse_dropout_graph_conv_activation_tf(keep_prob))

    # torch was summed in the loop; sum the TF terms with add_n, then
    # L2-normalize both row-wise before comparing.
    latent_torch = torch.nn.functional.normalize(latent_torch, p=2, dim=1)
    latent_tf = tf.add_n(latent_tf)
    latent_tf = tf.nn.l2_normalize(latent_tf, dim=1)

    # Materialize both results as numpy arrays (TF 1.x graph-mode eval).
    latent_torch = latent_torch.detach().numpy()
    latent_tf = latent_tf.eval(session=tf.Session())

    # Compare with abs() so large negative deviations fail too; the
    # one-sided check would silently pass them.
    assert np.all(np.abs(latent_torch - latent_tf) < .000001)