diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py index d58c35f..4048521 100644 --- a/src/decagon_pytorch/convolve.py +++ b/src/decagon_pytorch/convolve.py @@ -301,3 +301,46 @@ class DenseMultiDGCA(torch.nn.Module): out += f(x[i]) out = torch.nn.functional.normalize(out, p=2, dim=1) return out + + +class GraphConv(torch.nn.Module): + """Convolution layer for sparse AND dense inputs.""" + def __init__(self, in_channels: int, out_channels: int, + adjacency_matrix: torch.Tensor, **kwargs) -> None: + super().__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.weight = init_glorot(in_channels, out_channels) + self.adjacency_matrix = adjacency_matrix + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.sparse.mm(x, self.weight) \ + if x.is_sparse \ + else torch.mm(x, self.weight) + x = torch.sparse.mm(self.adjacency_matrix, x) \ + if self.adjacency_matrix.is_sparse \ + else torch.mm(self.adjacency_matrix, x) + return x + + +class DropoutGraphConvActivation(torch.nn.Module): + def __init__(self, input_dim: int, output_dim: int, + adjacency_matrix: torch.Tensor, keep_prob: float=1., + activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, + **kwargs) -> None: + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim + self.adjacency_matrix = adjacency_matrix + self.keep_prob = keep_prob + self.activation = activation + self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = dropout_sparse(x, self.keep_prob) \ + if x.is_sparse \ + else dropout(x, self.keep_prob) + x = self.graph_conv(x) + x = self.activation(x) + return x diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py index 302eec2..1f8b7a0 100644 --- a/tests/decagon_pytorch/test_convolve.py +++ b/tests/decagon_pytorch/test_convolve.py @@ -199,6 +199,18 @@ def teardown_function(fun): setup_function.old_dropout +def flexible_dropout_graph_conv_activation_torch(keep_prob=1.): + torch.random.manual_seed(0) + latent, adjacency_matrices = prepare_data() + latent = torch.tensor(latent).to_sparse() + adj_mat = adjacency_matrices[0] + adj_mat = torch.tensor(adj_mat).to_sparse() + conv = decagon_pytorch.convolve.DropoutGraphConvActivation(10, 10, + adj_mat, keep_prob=keep_prob) + latent = conv(latent) + return latent + + def test_dropout_graph_conv_activation(): for i in range(11): keep_prob = i/10. @@ -214,10 +226,22 @@ def test_dropout_graph_conv_activation(): latent_sparse = latent_sparse.detach().numpy() print('latent_sparse:', latent_sparse) + latent_flex = flexible_dropout_graph_conv_activation_torch(keep_prob) + latent_flex = latent_flex.detach().numpy() + print('latent_flex:', latent_flex) + nonzero = (latent_dense != 0) & (latent_sparse != 0) assert np.all(latent_dense[nonzero] == latent_sparse[nonzero]) + nonzero = (latent_dense != 0) & (latent_flex != 0) + + assert np.all(latent_dense[nonzero] == latent_flex[nonzero]) + + nonzero = (latent_sparse != 0) & (latent_flex != 0) + + assert np.all(latent_sparse[nonzero] == latent_flex[nonzero]) + def test_multi_dgca(): keep_prob = .5