IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Quellcode durchsuchen

test_multi_dgca() passes.

master
Stanislaw Adaszewski vor 4 Jahren
Ursprung
Commit
1d271acab8
1 geänderte Dateien mit 11 neuen und 1 gelöschten Zeilen
  1. +11
    -1
      tests/decagon_pytorch/test_convolve.py

+ 11
- 1
tests/decagon_pytorch/test_convolve.py Datei anzeigen

@@ -227,16 +227,26 @@ def test_multi_dgca():
latent_sparse = torch.tensor(latent).to_sparse()
latent = torch.tensor(latent)
assert np.all(latent_sparse.to_dense().numpy() == latent.numpy())
adjacency_matrices_sparse = [ torch.tensor(a).to_sparse() for a in adjacency_matrices ]
adjacency_matrices = [ torch.tensor(a) for a in adjacency_matrices ]
for i in range(len(adjacency_matrices)):
assert np.all(adjacency_matrices[i].numpy() == adjacency_matrices_sparse[i].to_dense().numpy())
torch.random.manual_seed(0)
multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA(10, 10, adjacency_matrices_sparse, keep_prob=keep_prob)
torch.random.manual_seed(0)
multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob)
for i in range(len(adjacency_matrices)):
assert np.all(multi_sparse.sparse_dgca[i].sparse_graph_conv.weight.detach().numpy() == multi.dgca[i].graph_conv.weight.detach().numpy())
# torch.random.manual_seed(0)
latent_sparse = multi_sparse(latent_sparse)
# torch.random.manual_seed(0)
latent = multi(latent)
assert np.all(latent_sparse.detach().numpy() - latent.detach().numpy() < .000001)
assert np.all(latent_sparse.detach().numpy() == latent.detach().numpy())

Laden…
Abbrechen
Speichern