From 1d271acab8b633fa0db1e14e2e48505eb2a1bb17 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 12 May 2020 22:32:07 +0200 Subject: [PATCH] test_multi_dgca() passes. --- tests/decagon_pytorch/test_convolve.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py index 1e2c538..210a4e5 100644 --- a/tests/decagon_pytorch/test_convolve.py +++ b/tests/decagon_pytorch/test_convolve.py @@ -227,16 +227,26 @@ def test_multi_dgca(): latent_sparse = torch.tensor(latent).to_sparse() latent = torch.tensor(latent) + assert np.all(latent_sparse.to_dense().numpy() == latent.numpy()) adjacency_matrices_sparse = [ torch.tensor(a).to_sparse() for a in adjacency_matrices ] adjacency_matrices = [ torch.tensor(a) for a in adjacency_matrices ] + for i in range(len(adjacency_matrices)): + assert np.all(adjacency_matrices[i].numpy() == adjacency_matrices_sparse[i].to_dense().numpy()) + + torch.random.manual_seed(0) multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA(10, 10, adjacency_matrices_sparse, keep_prob=keep_prob) + + torch.random.manual_seed(0) multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob) + for i in range(len(adjacency_matrices)): + assert np.all(multi_sparse.sparse_dgca[i].sparse_graph_conv.weight.detach().numpy() == multi.dgca[i].graph_conv.weight.detach().numpy()) + # torch.random.manual_seed(0) latent_sparse = multi_sparse(latent_sparse) # torch.random.manual_seed(0) latent = multi(latent) - assert np.all(latent_sparse.detach().numpy() - latent.detach().numpy() < .000001) + assert np.all(latent_sparse.detach().numpy() == latent.detach().numpy())