IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Add test_sparse_dropout_grap_conv_activation()

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
2d92cc95bb
共有 2 个文件被更改,包括 66 次插入8 次删除
  1. +1
    -1
      src/decagon_pytorch/convolve.py
  2. +65
    -7
      tests/decagon_pytorch/test_convolve.py

+ 1
- 1
src/decagon_pytorch/convolve.py 查看文件

@@ -26,7 +26,7 @@ class SparseDropoutGraphConvActivation(torch.nn.Module):
activation=torch.nn.functional.relu,
**kwargs):
super().__init__(**kwargs)
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim)
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
self.keep_prob = keep_prob
self.activation = activation


+ 65
- 7
tests/decagon_pytorch/test_convolve.py 查看文件

@@ -19,6 +19,26 @@ def prepare_data():
return latent, adjacency_matrices
def dense_to_sparse_tf(x):
a, b = np.where(x)
indices = np.array([a, b]).T
values = x[a, b]
return tf.sparse.SparseTensor(indices, values, x.shape)
def dropout_sparse_tf(x, keep_prob, num_nonzero_elems):
"""Dropout for sparse tensors. Currently fails for very large sparse tensors (>1M elements)
"""
noise_shape = [num_nonzero_elems]
random_tensor = keep_prob
random_tensor += tf.convert_to_tensor(torch.rand(noise_shape).detach().numpy())
# tf.convert_to_tensor(np.random.random(noise_shape))
# tf.random_uniform(noise_shape)
dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
pre_out = tf.sparse_retain(x, dropout_mask)
return pre_out * (1./keep_prob)
def sparse_graph_conv_torch():
torch.random.manual_seed(0)
latent, adjacency_matrices = prepare_data()
@@ -34,24 +54,51 @@ def sparse_graph_conv_torch():
return latent
def dense_to_sparse_tf(x):
a, b = np.where(x)
indices = np.array([a, b]).T
values = x[a, b]
return tf.sparse.SparseTensor(indices, values, x.shape)
def sparse_graph_conv_tf():
torch.random.manual_seed(0)
latent, adjacency_matrices = prepare_data()
conv_torch = decagon_pytorch.convolve.SparseGraphConv(10, 10,
torch.tensor(adjacency_matrices[0]).to_sparse())
weight = tf.constant(conv_torch.weight.detach().numpy())
latent = dense_to_sparse_tf(latent)
adj_mat = dense_to_sparse_tf(adjacency_matrices[0])
latent = tf.sparse_tensor_dense_matmul(latent, weight)
latent = tf.sparse_tensor_dense_matmul(adj_mat, latent)
return latent
def sparse_dropout_graph_conv_activation_torch(keep_prob=1.):
torch.random.manual_seed(0)
latent, adjacency_matrices = prepare_data()
latent = torch.tensor(latent).to_sparse()
adj_mat = adjacency_matrices[0]
adj_mat = torch.tensor(adj_mat).to_sparse()
conv = decagon_pytorch.convolve.SparseDropoutGraphConvActivation(10, 10,
adj_mat, keep_prob=keep_prob)
latent = conv(latent)
return latent
def sparse_graph_conv_tf():
def sparse_dropout_graph_conv_activation_tf(keep_prob=1.):
torch.random.manual_seed(0)
latent, adjacency_matrices = prepare_data()
conv_torch = decagon_pytorch.convolve.SparseGraphConv(10, 10,
torch.tensor(adjacency_matrices[0]).to_sparse())
weight = tf.constant(conv_torch.weight.detach().numpy())
nonzero_feat = np.sum(latent > 0)
latent = dense_to_sparse_tf(latent)
latent = dropout_sparse_tf(latent, keep_prob,
nonzero_feat)
adj_mat = dense_to_sparse_tf(adjacency_matrices[0])
latent = tf.sparse_tensor_dense_matmul(latent, weight)
latent = tf.sparse_tensor_dense_matmul(adj_mat, latent)
latent = tf.nn.relu(latent)
return latent
@@ -62,7 +109,18 @@ def test_sparse_graph_conv():
def test_sparse_dropout_grap_conv_activation():
pass
for i in range(11):
keep_prob = i/10. + np.finfo(np.float32).eps
latent_torch = sparse_dropout_graph_conv_activation_torch(keep_prob)
latent_tf = sparse_dropout_graph_conv_activation_tf(keep_prob)
latent_torch = latent_torch.detach().numpy()
latent_tf = latent_tf.eval(session = tf.Session())
print('latent_torch:', latent_torch)
print('latent_tf:', latent_tf)
assert np.all(latent_torch - latent_tf < .000001)
def test_sparse_multi_dgca():


正在加载...
取消
保存