diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py index d125e98..aec9c04 100644 --- a/src/decagon_pytorch/convolve.py +++ b/src/decagon_pytorch/convolve.py @@ -26,7 +26,7 @@ class SparseDropoutGraphConvActivation(torch.nn.Module): activation=torch.nn.functional.relu, **kwargs): super().__init__(**kwargs) - self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim) + self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix) self.keep_prob = keep_prob self.activation = activation diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py index 611f0a1..2bf603a 100644 --- a/tests/decagon_pytorch/test_convolve.py +++ b/tests/decagon_pytorch/test_convolve.py @@ -19,6 +19,26 @@ def prepare_data(): return latent, adjacency_matrices +def dense_to_sparse_tf(x): + a, b = np.where(x) + indices = np.array([a, b]).T + values = x[a, b] + return tf.sparse.SparseTensor(indices, values, x.shape) + + +def dropout_sparse_tf(x, keep_prob, num_nonzero_elems): + """Dropout for sparse tensors. Currently fails for very large sparse tensors (>1M elements) + """ + noise_shape = [num_nonzero_elems] + random_tensor = keep_prob + random_tensor += tf.convert_to_tensor(torch.rand(noise_shape).detach().numpy()) + # tf.convert_to_tensor(np.random.random(noise_shape)) + # tf.random_uniform(noise_shape) + dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool) + pre_out = tf.sparse_retain(x, dropout_mask) + return pre_out * (1./keep_prob) + + def sparse_graph_conv_torch(): torch.random.manual_seed(0) latent, adjacency_matrices = prepare_data() @@ -34,24 +54,51 @@ def sparse_graph_conv_torch(): return latent -def dense_to_sparse_tf(x): - a, b = np.where(x) - indices = np.array([a, b]).T - values = x[a, b] - return tf.sparse.SparseTensor(indices, values, x.shape) +def sparse_graph_conv_tf(): + torch.random.manual_seed(0) + latent, adjacency_matrices = prepare_data() + conv_torch = decagon_pytorch.convolve.SparseGraphConv(10, 10, + torch.tensor(adjacency_matrices[0]).to_sparse()) + weight = tf.constant(conv_torch.weight.detach().numpy()) + latent = dense_to_sparse_tf(latent) + adj_mat = dense_to_sparse_tf(adjacency_matrices[0]) + latent = tf.sparse_tensor_dense_matmul(latent, weight) + latent = tf.sparse_tensor_dense_matmul(adj_mat, latent) + return latent +def sparse_dropout_graph_conv_activation_torch(keep_prob=1.): + torch.random.manual_seed(0) + latent, adjacency_matrices = prepare_data() + latent = torch.tensor(latent).to_sparse() + adj_mat = adjacency_matrices[0] + adj_mat = torch.tensor(adj_mat).to_sparse() + conv = decagon_pytorch.convolve.SparseDropoutGraphConvActivation(10, 10, + adj_mat, keep_prob=keep_prob) + latent = conv(latent) + return latent -def sparse_graph_conv_tf(): + +def sparse_dropout_graph_conv_activation_tf(keep_prob=1.): torch.random.manual_seed(0) latent, adjacency_matrices = prepare_data() conv_torch = decagon_pytorch.convolve.SparseGraphConv(10, 10, torch.tensor(adjacency_matrices[0]).to_sparse()) + weight = tf.constant(conv_torch.weight.detach().numpy()) + nonzero_feat = np.sum(latent > 0) + latent = dense_to_sparse_tf(latent) + latent = dropout_sparse_tf(latent, keep_prob, + nonzero_feat) + adj_mat = dense_to_sparse_tf(adjacency_matrices[0]) + latent = tf.sparse_tensor_dense_matmul(latent, weight) latent = tf.sparse_tensor_dense_matmul(adj_mat, latent) + + latent = tf.nn.relu(latent) + return latent @@ -62,7 +109,18 @@ def test_sparse_graph_conv(): def test_sparse_dropout_grap_conv_activation(): - pass + for i in range(11): + keep_prob = i/10. + np.finfo(np.float32).eps + + latent_torch = sparse_dropout_graph_conv_activation_torch(keep_prob) + latent_tf = sparse_dropout_graph_conv_activation_tf(keep_prob) + + latent_torch = latent_torch.detach().numpy() + latent_tf = latent_tf.eval(session = tf.Session()) + print('latent_torch:', latent_torch) + print('latent_tf:', latent_tf) + + assert np.all(latent_torch - latent_tf < .000001) def test_sparse_multi_dgca():