| @@ -212,13 +212,25 @@ class SparseMultiDGCA(torch.nn.Module): | |||||
| activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | ||||
| **kwargs) -> None: | **kwargs) -> None: | ||||
| super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
| self.input_dim = input_dim | |||||
| self.output_dim = output_dim | self.output_dim = output_dim | ||||
| self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] | |||||
| def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: | |||||
| out = torch.zeros(len(x), self.output_dim, dtype=x.dtype) | |||||
| for f in self.sparse_dgca: | |||||
| out += f(x) | |||||
| self.adjacency_matrices = adjacency_matrices | |||||
| self.keep_prob = keep_prob | |||||
| self.activation = activation | |||||
| self.sparse_dgca = None | |||||
| self.build() | |||||
| def build(self): | |||||
| if len(self.input_dim) != len(self.adjacency_matrices): | |||||
| raise ValueError('input_dim must have the same length as adjacency_matrices') | |||||
| self.sparse_dgca = [] | |||||
| for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): | |||||
| self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) | |||||
| def forward(self, x: List[torch.Tensor]) -> torch.Tensor: | |||||
| out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) | |||||
| for i, f in enumerate(self.sparse_dgca): | |||||
| out += f(x[i]) | |||||
| out = torch.nn.functional.normalize(out, p=2, dim=1) | out = torch.nn.functional.normalize(out, p=2, dim=1) | ||||
| return out | return out | ||||
| @@ -236,16 +236,20 @@ def test_multi_dgca(): | |||||
| assert np.all(adjacency_matrices[i].numpy() == adjacency_matrices_sparse[i].to_dense().numpy()) | assert np.all(adjacency_matrices[i].numpy() == adjacency_matrices_sparse[i].to_dense().numpy()) | ||||
| torch.random.manual_seed(0) | torch.random.manual_seed(0) | ||||
| multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA(10, 10, adjacency_matrices_sparse, keep_prob=keep_prob) | |||||
| multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA([10,]*len(adjacency_matrices), 10, adjacency_matrices_sparse, keep_prob=keep_prob) | |||||
| torch.random.manual_seed(0) | torch.random.manual_seed(0) | ||||
| multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob) | multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob) | ||||
| print('len(adjacency_matrices):', len(adjacency_matrices)) | |||||
| print('len(multi_sparse.sparse_dgca):', len(multi_sparse.sparse_dgca)) | |||||
| print('len(multi.dgca):', len(multi.dgca)) | |||||
| for i in range(len(adjacency_matrices)): | for i in range(len(adjacency_matrices)): | ||||
| assert np.all(multi_sparse.sparse_dgca[i].sparse_graph_conv.weight.detach().numpy() == multi.dgca[i].graph_conv.weight.detach().numpy()) | assert np.all(multi_sparse.sparse_dgca[i].sparse_graph_conv.weight.detach().numpy() == multi.dgca[i].graph_conv.weight.detach().numpy()) | ||||
| # torch.random.manual_seed(0) | # torch.random.manual_seed(0) | ||||
| latent_sparse = multi_sparse(latent_sparse) | |||||
| latent_sparse = multi_sparse([latent_sparse,] * len(adjacency_matrices)) | |||||
| # torch.random.manual_seed(0) | # torch.random.manual_seed(0) | ||||
| latent = multi(latent) | latent = multi(latent) | ||||