IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Sfoglia il codice sorgente

Need to de-bug test_multi_dgca().

master
Stanislaw Adaszewski 4 anni fa
parent
commit
f8cec09a09
2 ha cambiato i file con 31 aggiunte e 2 eliminazioni
  1. +2
    -0
      src/decagon_pytorch/convolve.py
  2. +29
    -2
      tests/decagon_pytorch/test_convolve.py

+ 2
- 0
src/decagon_pytorch/convolve.py Vedi File

@@ -44,6 +44,7 @@ class SparseMultiDGCA(torch.nn.Module):
activation=torch.nn.functional.relu,
**kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
def forward(self, x):
@@ -92,6 +93,7 @@ class MultiDGCA(torch.nn.Module):
activation=torch.nn.functional.relu,
**kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
def forward(self, x):


+ 29
- 2
tests/decagon_pytorch/test_convolve.py Vedi File

@@ -180,7 +180,9 @@ def test_graph_conv():
def setup_function(fun):
if fun == test_dropout_graph_conv_activation:
if fun == test_dropout_graph_conv_activation or \
fun == test_multi_dgca:
print('Disabling dropout for testing...')
setup_function.old_dropout = decagon_pytorch.convolve.dropout, \
decagon_pytorch.convolve.dropout_sparse
@@ -189,7 +191,9 @@ def setup_function(fun):
def teardown_function(fun):
if fun == test_dropout_graph_conv_activation:
print('Re-enabling dropout...')
if fun == test_dropout_graph_conv_activation or \
fun == test_multi_dgca:
decagon_pytorch.convolve.dropout, \
decagon_pytorch.convolve.dropout_sparse = \
setup_function.old_dropout
@@ -213,3 +217,26 @@ def test_dropout_graph_conv_activation():
nonzero = (latent_dense != 0) & (latent_sparse != 0)
assert np.all(latent_dense[nonzero] == latent_sparse[nonzero])
def test_multi_dgca():
keep_prob = .5
torch.random.manual_seed(0)
latent, adjacency_matrices = prepare_data()
latent_sparse = torch.tensor(latent).to_sparse()
latent = torch.tensor(latent)
adjacency_matrices_sparse = [ torch.tensor(a).to_sparse() for a in adjacency_matrices ]
adjacency_matrices = [ torch.tensor(a) for a in adjacency_matrices ]
multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA(10, 10, adjacency_matrices_sparse, keep_prob=keep_prob)
multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob)
# torch.random.manual_seed(0)
latent_sparse = multi_sparse(latent_sparse)
# torch.random.manual_seed(0)
latent = multi(latent)
assert np.all(latent_sparse.detach().numpy() - latent.detach().numpy() < .000001)

Loading…
Annulla
Salva