@@ -301,3 +301,46 @@ class DenseMultiDGCA(torch.nn.Module): | |||||
out += f(x[i]) | out += f(x[i]) | ||||
out = torch.nn.functional.normalize(out, p=2, dim=1) | out = torch.nn.functional.normalize(out, p=2, dim=1) | ||||
return out | return out | ||||
class GraphConv(torch.nn.Module): | |||||
"""Convolution layer for sparse AND dense inputs.""" | |||||
def __init__(self, in_channels: int, out_channels: int, | |||||
adjacency_matrix: torch.Tensor, **kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.in_channels = in_channels | |||||
self.out_channels = out_channels | |||||
self.weight = init_glorot(in_channels, out_channels) | |||||
self.adjacency_matrix = adjacency_matrix | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = torch.sparse.mm(x, self.weight) \ | |||||
if x.is_sparse \ | |||||
else torch.mm(x, self.weight) | |||||
x = torch.sparse.mm(self.adjacency_matrix, x) \ | |||||
if self.adjacency_matrix.is_sparse \ | |||||
else torch.mm(self.adjacency_matrix, x) | |||||
return x | |||||
class DropoutGraphConvActivation(torch.nn.Module): | |||||
def __init__(self, input_dim: int, output_dim: int, | |||||
adjacency_matrix: torch.Tensor, keep_prob: float=1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.adjacency_matrix = adjacency_matrix | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
x = dropout_sparse(x, self.keep_prob) \ | |||||
if x.is_sparse \ | |||||
else dropout(x, self.keep_prob) | |||||
x = self.graph_conv(x) | |||||
x = self.activation(x) | |||||
return x |
@@ -199,6 +199,18 @@ def teardown_function(fun): | |||||
setup_function.old_dropout | setup_function.old_dropout | ||||
def flexible_dropout_graph_conv_activation_torch(keep_prob=1.): | |||||
torch.random.manual_seed(0) | |||||
latent, adjacency_matrices = prepare_data() | |||||
latent = torch.tensor(latent).to_sparse() | |||||
adj_mat = adjacency_matrices[0] | |||||
adj_mat = torch.tensor(adj_mat).to_sparse() | |||||
conv = decagon_pytorch.convolve.DropoutGraphConvActivation(10, 10, | |||||
adj_mat, keep_prob=keep_prob) | |||||
latent = conv(latent) | |||||
return latent | |||||
def test_dropout_graph_conv_activation(): | def test_dropout_graph_conv_activation(): | ||||
for i in range(11): | for i in range(11): | ||||
keep_prob = i/10. | keep_prob = i/10. | ||||
@@ -214,10 +226,22 @@ def test_dropout_graph_conv_activation(): | |||||
latent_sparse = latent_sparse.detach().numpy() | latent_sparse = latent_sparse.detach().numpy() | ||||
print('latent_sparse:', latent_sparse) | print('latent_sparse:', latent_sparse) | ||||
latent_flex = flexible_dropout_graph_conv_activation_torch(keep_prob) | |||||
latent_flex = latent_flex.detach().numpy() | |||||
print('latent_flex:', latent_flex) | |||||
nonzero = (latent_dense != 0) & (latent_sparse != 0) | nonzero = (latent_dense != 0) & (latent_sparse != 0) | ||||
assert np.all(latent_dense[nonzero] == latent_sparse[nonzero]) | assert np.all(latent_dense[nonzero] == latent_sparse[nonzero]) | ||||
nonzero = (latent_dense != 0) & (latent_flex != 0) | |||||
assert np.all(latent_dense[nonzero] == latent_flex[nonzero]) | |||||
nonzero = (latent_sparse != 0) & (latent_flex != 0) | |||||
assert np.all(latent_sparse[nonzero] == latent_flex[nonzero]) | |||||
def test_multi_dgca(): | def test_multi_dgca(): | ||||
keep_prob = .5 | keep_prob = .5 | ||||