|
|
@@ -227,16 +227,26 @@ def test_multi_dgca(): |
|
|
|
|
|
|
|
latent_sparse = torch.tensor(latent).to_sparse()
|
|
|
|
latent = torch.tensor(latent)
|
|
|
|
assert np.all(latent_sparse.to_dense().numpy() == latent.numpy())
|
|
|
|
|
|
|
|
adjacency_matrices_sparse = [ torch.tensor(a).to_sparse() for a in adjacency_matrices ]
|
|
|
|
adjacency_matrices = [ torch.tensor(a) for a in adjacency_matrices ]
|
|
|
|
|
|
|
|
for i in range(len(adjacency_matrices)):
|
|
|
|
assert np.all(adjacency_matrices[i].numpy() == adjacency_matrices_sparse[i].to_dense().numpy())
|
|
|
|
|
|
|
|
torch.random.manual_seed(0)
|
|
|
|
multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA(10, 10, adjacency_matrices_sparse, keep_prob=keep_prob)
|
|
|
|
|
|
|
|
torch.random.manual_seed(0)
|
|
|
|
multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob)
|
|
|
|
|
|
|
|
for i in range(len(adjacency_matrices)):
|
|
|
|
assert np.all(multi_sparse.sparse_dgca[i].sparse_graph_conv.weight.detach().numpy() == multi.dgca[i].graph_conv.weight.detach().numpy())
|
|
|
|
|
|
|
|
# torch.random.manual_seed(0)
|
|
|
|
latent_sparse = multi_sparse(latent_sparse)
|
|
|
|
# torch.random.manual_seed(0)
|
|
|
|
latent = multi(latent)
|
|
|
|
|
|
|
|
assert np.all(latent_sparse.detach().numpy() - latent.detach().numpy() < .000001)
|
|
|
|
assert np.all(latent_sparse.detach().numpy() == latent.detach().numpy())
|