From 99dbcdeb916b23348320a4568e43460a84c465d9 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Thu, 28 May 2020 11:59:43 +0200
Subject: [PATCH] Prefix dense versions of convolution classes with Dense.

---
 src/decagon_pytorch/convolve.py        | 10 +++++-----
 tests/decagon_pytorch/test_convolve.py | 14 +++++++-------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py
index 9087f6d..d58c35f 100644
--- a/src/decagon_pytorch/convolve.py
+++ b/src/decagon_pytorch/convolve.py
@@ -240,7 +240,7 @@ class SparseMultiDGCA(torch.nn.Module):
         return out
 
 
-class GraphConv(torch.nn.Module):
+class DenseGraphConv(torch.nn.Module):
     def __init__(self, in_channels: int, out_channels: int,
         adjacency_matrix: torch.Tensor, **kwargs) -> None:
         super().__init__(**kwargs)
@@ -255,13 +255,13 @@
         return x
 
 
-class DropoutGraphConvActivation(torch.nn.Module):
+class DenseDropoutGraphConvActivation(torch.nn.Module):
     def __init__(self, input_dim: int, output_dim: int,
         adjacency_matrix: torch.Tensor, keep_prob: float=1.,
         activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
         **kwargs) -> None:
         super().__init__(**kwargs)
-        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
+        self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix)
         self.keep_prob = keep_prob
         self.activation = activation
 
@@ -272,7 +272,7 @@ class DropoutGraphConvActivation(torch.nn.Module):
         return x
 
 
-class MultiDGCA(torch.nn.Module):
+class DenseMultiDGCA(torch.nn.Module):
     def __init__(self, input_dim: List[int], output_dim: int,
         adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
         activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
@@ -291,7 +291,7 @@
             raise ValueError('input_dim must have the same length as adjacency_matrices')
         self.dgca = []
         for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
-            self.dgca.append(DropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
+            self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
 
     def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
         if not isinstance(x, list):
diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py
index b529bce..302eec2 100644
--- a/tests/decagon_pytorch/test_convolve.py
+++ b/tests/decagon_pytorch/test_convolve.py
@@ -41,25 +41,25 @@ def dropout_sparse_tf(x, keep_prob, num_nonzero_elems):
     return pre_out * (1./keep_prob)
 
 
-def graph_conv_torch():
+def dense_graph_conv_torch():
     torch.random.manual_seed(0)
     latent, adjacency_matrices = prepare_data()
     latent = torch.tensor(latent)
     adj_mat = adjacency_matrices[0]
     adj_mat = torch.tensor(adj_mat)
-    conv = decagon_pytorch.convolve.GraphConv(10, 10,
+    conv = decagon_pytorch.convolve.DenseGraphConv(10, 10,
         adj_mat)
     latent = conv(latent)
     return latent
 
 
-def dropout_graph_conv_activation_torch(keep_prob=1.):
+def dense_dropout_graph_conv_activation_torch(keep_prob=1.):
     torch.random.manual_seed(0)
     latent, adjacency_matrices = prepare_data()
     latent = torch.tensor(latent)
     adj_mat = adjacency_matrices[0]
     adj_mat = torch.tensor(adj_mat)
-    conv = decagon_pytorch.convolve.DropoutGraphConvActivation(10, 10,
+    conv = decagon_pytorch.convolve.DenseDropoutGraphConvActivation(10, 10,
         adj_mat, keep_prob=keep_prob)
     latent = conv(latent)
     return latent
@@ -173,7 +173,7 @@ def test_sparse_multi_dgca():
 
 
 def test_graph_conv():
-    latent_dense = graph_conv_torch()
+    latent_dense = dense_graph_conv_torch()
     latent_sparse = sparse_graph_conv_torch()
 
     assert np.all(latent_dense.detach().numpy() == latent_sparse.detach().numpy())
@@ -206,7 +206,7 @@ def test_dropout_graph_conv_activation():
         keep_prob += np.finfo(np.float32).eps
         print('keep_prob:', keep_prob)
 
-        latent_dense = dropout_graph_conv_activation_torch(keep_prob)
+        latent_dense = dense_dropout_graph_conv_activation_torch(keep_prob)
         latent_dense = latent_dense.detach().numpy()
         print('latent_dense:', latent_dense)
 
@@ -239,7 +239,7 @@ def test_multi_dgca():
     multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA([10,] * len(adjacency_matrices), 10, adjacency_matrices_sparse, keep_prob=keep_prob)
 
     torch.random.manual_seed(0)
-    multi = decagon_pytorch.convolve.MultiDGCA([10,] * len(adjacency_matrices), 10, adjacency_matrices, keep_prob=keep_prob)
+    multi = decagon_pytorch.convolve.DenseMultiDGCA([10,] * len(adjacency_matrices), 10, adjacency_matrices, keep_prob=keep_prob)
 
     print('len(adjacency_matrices):', len(adjacency_matrices))
     print('len(multi_sparse.sparse_dgca):', len(multi_sparse.sparse_dgca))
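
After this rename, callers construct the dense variants under their new names. The snippet below is a minimal usage sketch and is not part of the patch: the random input tensors, the number of adjacency matrices, and the keep_prob value are illustrative assumptions; only the class names and the constructor/forward signatures come from the hunks above.

import torch
import decagon_pytorch.convolve

n_nodes = 10
# Illustrative dense adjacency matrices (random 0/1 entries, assumption only).
adjacency_matrices = [(torch.rand(n_nodes, n_nodes) > 0.5).float() for _ in range(3)]

# Single dense graph convolution over one adjacency matrix.
conv = decagon_pytorch.convolve.DenseGraphConv(10, 10, adjacency_matrices[0])
latent = conv(torch.rand(n_nodes, 10))

# Dropout + graph convolution + activation for one adjacency matrix.
dgca = decagon_pytorch.convolve.DenseDropoutGraphConvActivation(10, 10,
    adjacency_matrices[0], keep_prob=0.9)
latent = dgca(torch.rand(n_nodes, 10))

# One branch per adjacency matrix, combined by DenseMultiDGCA;
# forward() expects a list of per-branch input tensors.
multi = decagon_pytorch.convolve.DenseMultiDGCA([10] * len(adjacency_matrices), 10,
    adjacency_matrices, keep_prob=0.9)
out = multi([torch.rand(n_nodes, 10) for _ in adjacency_matrices])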