@@ -180,7 +180,9 @@ def test_graph_conv():


 def setup_function(fun):
-    if fun == test_dropout_graph_conv_activation:
+    if fun == test_dropout_graph_conv_activation or \
+        fun == test_multi_dgca:
         print('Disabling dropout for testing...')
         setup_function.old_dropout = decagon_pytorch.convolve.dropout, \
             decagon_pytorch.convolve.dropout_sparse
@@ -189,7 +191,9 @@ def setup_function(fun):


 def teardown_function(fun):
-    if fun == test_dropout_graph_conv_activation:
+    if fun == test_dropout_graph_conv_activation or \
+        fun == test_multi_dgca:
         print('Re-enabling dropout...')
         decagon_pytorch.convolve.dropout, \
             decagon_pytorch.convolve.dropout_sparse = \
             setup_function.old_dropout
@@ -213,3 +217,26 @@ def test_dropout_graph_conv_activation():
     nonzero = (latent_dense != 0) & (latent_sparse != 0)

     assert np.all(latent_dense[nonzero] == latent_sparse[nonzero])
+
+
+def test_multi_dgca():
+    keep_prob = .5
+
+    torch.random.manual_seed(0)
+    latent, adjacency_matrices = prepare_data()
+
+    latent_sparse = torch.tensor(latent).to_sparse()
+    latent = torch.tensor(latent)
+
+    adjacency_matrices_sparse = [ torch.tensor(a).to_sparse() for a in adjacency_matrices ]
+    adjacency_matrices = [ torch.tensor(a) for a in adjacency_matrices ]
+
+    multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA(10, 10, adjacency_matrices_sparse, keep_prob=keep_prob)
+    multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob)
+
+    # torch.random.manual_seed(0)
+    latent_sparse = multi_sparse(latent_sparse)
+    # torch.random.manual_seed(0)
+    latent = multi(latent)
+
+    assert np.all(np.abs(latent_sparse.detach().numpy() - latent.detach().numpy()) < .000001)