| @@ -301,3 +301,46 @@ class DenseMultiDGCA(torch.nn.Module): | |||||
| out += f(x[i]) | out += f(x[i]) | ||||
| out = torch.nn.functional.normalize(out, p=2, dim=1) | out = torch.nn.functional.normalize(out, p=2, dim=1) | ||||
| return out | return out | ||||
| class GraphConv(torch.nn.Module): | |||||
| """Convolution layer for sparse AND dense inputs.""" | |||||
| def __init__(self, in_channels: int, out_channels: int, | |||||
| adjacency_matrix: torch.Tensor, **kwargs) -> None: | |||||
| super().__init__(**kwargs) | |||||
| self.in_channels = in_channels | |||||
| self.out_channels = out_channels | |||||
| self.weight = init_glorot(in_channels, out_channels) | |||||
| self.adjacency_matrix = adjacency_matrix | |||||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
| x = torch.sparse.mm(x, self.weight) \ | |||||
| if x.is_sparse \ | |||||
| else torch.mm(x, self.weight) | |||||
| x = torch.sparse.mm(self.adjacency_matrix, x) \ | |||||
| if self.adjacency_matrix.is_sparse \ | |||||
| else torch.mm(self.adjacency_matrix, x) | |||||
| return x | |||||
| class DropoutGraphConvActivation(torch.nn.Module): | |||||
| def __init__(self, input_dim: int, output_dim: int, | |||||
| adjacency_matrix: torch.Tensor, keep_prob: float=1., | |||||
| activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
| **kwargs) -> None: | |||||
| super().__init__(**kwargs) | |||||
| self.input_dim = input_dim | |||||
| self.output_dim = output_dim | |||||
| self.adjacency_matrix = adjacency_matrix | |||||
| self.keep_prob = keep_prob | |||||
| self.activation = activation | |||||
| self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) | |||||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
| x = dropout_sparse(x, self.keep_prob) \ | |||||
| if x.is_sparse \ | |||||
| else dropout(x, self.keep_prob) | |||||
| x = self.graph_conv(x) | |||||
| x = self.activation(x) | |||||
| return x | |||||
| @@ -199,6 +199,18 @@ def teardown_function(fun): | |||||
| setup_function.old_dropout | setup_function.old_dropout | ||||
| def flexible_dropout_graph_conv_activation_torch(keep_prob=1.): | |||||
| torch.random.manual_seed(0) | |||||
| latent, adjacency_matrices = prepare_data() | |||||
| latent = torch.tensor(latent).to_sparse() | |||||
| adj_mat = adjacency_matrices[0] | |||||
| adj_mat = torch.tensor(adj_mat).to_sparse() | |||||
| conv = decagon_pytorch.convolve.DropoutGraphConvActivation(10, 10, | |||||
| adj_mat, keep_prob=keep_prob) | |||||
| latent = conv(latent) | |||||
| return latent | |||||
| def test_dropout_graph_conv_activation(): | def test_dropout_graph_conv_activation(): | ||||
| for i in range(11): | for i in range(11): | ||||
| keep_prob = i/10. | keep_prob = i/10. | ||||
| @@ -214,10 +226,22 @@ def test_dropout_graph_conv_activation(): | |||||
| latent_sparse = latent_sparse.detach().numpy() | latent_sparse = latent_sparse.detach().numpy() | ||||
| print('latent_sparse:', latent_sparse) | print('latent_sparse:', latent_sparse) | ||||
| latent_flex = flexible_dropout_graph_conv_activation_torch(keep_prob) | |||||
| latent_flex = latent_flex.detach().numpy() | |||||
| print('latent_flex:', latent_flex) | |||||
| nonzero = (latent_dense != 0) & (latent_sparse != 0) | nonzero = (latent_dense != 0) & (latent_sparse != 0) | ||||
| assert np.all(latent_dense[nonzero] == latent_sparse[nonzero]) | assert np.all(latent_dense[nonzero] == latent_sparse[nonzero]) | ||||
| nonzero = (latent_dense != 0) & (latent_flex != 0) | |||||
| assert np.all(latent_dense[nonzero] == latent_flex[nonzero]) | |||||
| nonzero = (latent_sparse != 0) & (latent_flex != 0) | |||||
| assert np.all(latent_sparse[nonzero] == latent_flex[nonzero]) | |||||
| def test_multi_dgca(): | def test_multi_dgca(): | ||||
| keep_prob = .5 | keep_prob = .5 | ||||