| @@ -227,16 +227,26 @@ def test_multi_dgca(): | |||||
| latent_sparse = torch.tensor(latent).to_sparse() | latent_sparse = torch.tensor(latent).to_sparse() | ||||
| latent = torch.tensor(latent) | latent = torch.tensor(latent) | ||||
| assert np.all(latent_sparse.to_dense().numpy() == latent.numpy()) | |||||
| adjacency_matrices_sparse = [ torch.tensor(a).to_sparse() for a in adjacency_matrices ] | adjacency_matrices_sparse = [ torch.tensor(a).to_sparse() for a in adjacency_matrices ] | ||||
| adjacency_matrices = [ torch.tensor(a) for a in adjacency_matrices ] | adjacency_matrices = [ torch.tensor(a) for a in adjacency_matrices ] | ||||
| for i in range(len(adjacency_matrices)): | |||||
| assert np.all(adjacency_matrices[i].numpy() == adjacency_matrices_sparse[i].to_dense().numpy()) | |||||
| torch.random.manual_seed(0) | |||||
| multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA(10, 10, adjacency_matrices_sparse, keep_prob=keep_prob) | multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA(10, 10, adjacency_matrices_sparse, keep_prob=keep_prob) | ||||
| torch.random.manual_seed(0) | |||||
| multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob) | multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob) | ||||
| for i in range(len(adjacency_matrices)): | |||||
| assert np.all(multi_sparse.sparse_dgca[i].sparse_graph_conv.weight.detach().numpy() == multi.dgca[i].graph_conv.weight.detach().numpy()) | |||||
| # torch.random.manual_seed(0) | # torch.random.manual_seed(0) | ||||
| latent_sparse = multi_sparse(latent_sparse) | latent_sparse = multi_sparse(latent_sparse) | ||||
| # torch.random.manual_seed(0) | # torch.random.manual_seed(0) | ||||
| latent = multi(latent) | latent = multi(latent) | ||||
| assert np.all(latent_sparse.detach().numpy() - latent.detach().numpy() < .000001) | |||||
| assert np.all(latent_sparse.detach().numpy() == latent.detach().numpy()) | |||||