From 2e5e1d69acd53d4aaf282cfcd1e8f0df112d5b19 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 22 May 2020 14:24:00 +0200 Subject: [PATCH] Add support for multiple input_dim in MultiDGCA. --- src/decagon_pytorch/convolve.py | 24 ++++++++++++++++++++---- tests/decagon_pytorch/test_convolve.py | 6 +++--- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py index 4ff9ee8..10475c6 100644 --- a/src/decagon_pytorch/convolve.py +++ b/src/decagon_pytorch/convolve.py @@ -228,6 +228,8 @@ class SparseMultiDGCA(torch.nn.Module): self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + if not isinstance(x, list): + raise ValueError('x must be a list of tensors') out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) for i, f in enumerate(self.sparse_dgca): out += f(x[i]) @@ -273,12 +275,26 @@ class MultiDGCA(torch.nn.Module): activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, **kwargs) -> None: super().__init__(**kwargs) + self.input_dim = input_dim self.output_dim = output_dim - self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] + self.adjacency_matrices = adjacency_matrices + self.keep_prob = keep_prob + self.activation = activation + self.dgca = None + self.build() + + def build(self): + if len(self.input_dim) != len(self.adjacency_matrices): + raise ValueError('input_dim must have the same length as adjacency_matrices') + self.dgca = [] + for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): + self.dgca.append(DropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: - out = torch.zeros(len(x), self.output_dim, dtype=x.dtype) - for f in self.dgca: - out += f(x) + if not isinstance(x, list): + raise ValueError('x must be a list of tensors') + out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) + for i, f in enumerate(self.dgca): + out += f(x[i]) out = torch.nn.functional.normalize(out, p=2, dim=1) return out diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py index 71e1ece..b529bce 100644 --- a/tests/decagon_pytorch/test_convolve.py +++ b/tests/decagon_pytorch/test_convolve.py @@ -236,10 +236,10 @@ def test_multi_dgca(): assert np.all(adjacency_matrices[i].numpy() == adjacency_matrices_sparse[i].to_dense().numpy()) torch.random.manual_seed(0) - multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA([10,]*len(adjacency_matrices), 10, adjacency_matrices_sparse, keep_prob=keep_prob) + multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA([10,] * len(adjacency_matrices), 10, adjacency_matrices_sparse, keep_prob=keep_prob) torch.random.manual_seed(0) - multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob) + multi = decagon_pytorch.convolve.MultiDGCA([10,] * len(adjacency_matrices), 10, adjacency_matrices, keep_prob=keep_prob) print('len(adjacency_matrices):', len(adjacency_matrices)) print('len(multi_sparse.sparse_dgca):', len(multi_sparse.sparse_dgca)) @@ -251,6 +251,6 @@ def test_multi_dgca(): # torch.random.manual_seed(0) latent_sparse = multi_sparse([latent_sparse,] * len(adjacency_matrices)) # torch.random.manual_seed(0) - latent = multi(latent) + latent = multi([latent,] * len(adjacency_matrices)) assert np.all(latent_sparse.detach().numpy() == latent.detach().numpy())