diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py index aec9c04..44252f3 100644 --- a/src/decagon_pytorch/convolve.py +++ b/src/decagon_pytorch/convolve.py @@ -46,8 +46,56 @@ class SparseMultiDGCA(torch.nn.Module): self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] def forward(self, x): - out = torch.zeros(len(x), output_dim, dtype=x.dtype) + out = torch.zeros(len(x), self.output_dim, dtype=x.dtype) for f in self.sparse_dgca: out += f(x) out = torch.nn.functional.normalize(out, p=2, dim=1) return out + + +class GraphConv(torch.nn.Module): + def __init__(self, in_channels, out_channels, + adjacency_matrix, **kwargs): + super().__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.weight = init_glorot(in_channels, out_channels) + self.adjacency_matrix = adjacency_matrix + + def forward(self, x): + x = torch.mm(x, self.weight) + x = torch.mm(self.adjacency_matrix, x) + return x + + +class DropoutGraphConvActivation(torch.nn.Module): + def __init__(self, input_dim, output_dim, + adjacency_matrix, keep_prob=1., + activation=torch.nn.functional.relu, + **kwargs): + super().__init__(**kwargs) + self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) + self.keep_prob = keep_prob + self.activation = activation + + def forward(self, x): + x = torch.nn.functional.dropout(x, 1.-self.keep_prob) + x = self.graph_conv(x) + x = self.activation(x) + return x + + +class MultiDGCA(torch.nn.Module): + def __init__(self, input_dim, output_dim, + adjacency_matrices, keep_prob=1., + activation=torch.nn.functional.relu, + **kwargs): + super().__init__(**kwargs) + self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] + + def forward(self, x): + out = torch.zeros(len(x), self.output_dim, dtype=x.dtype) + for f in self.dgca: + out += f(x) + out = torch.nn.functional.normalize(out, p=2, dim=1) + return out diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py index 1624a80..3adb9fd 100644 --- a/tests/decagon_pytorch/test_convolve.py +++ b/tests/decagon_pytorch/test_convolve.py @@ -39,6 +39,18 @@ def dropout_sparse_tf(x, keep_prob, num_nonzero_elems): return pre_out * (1./keep_prob) +def graph_conv_torch(): + torch.random.manual_seed(0) + latent, adjacency_matrices = prepare_data() + latent = torch.tensor(latent) + adj_mat = adjacency_matrices[0] + adj_mat = torch.tensor(adj_mat) + conv = decagon_pytorch.convolve.GraphConv(10, 10, + adj_mat) + latent = conv(latent) + return latent + + def sparse_graph_conv_torch(): torch.random.manual_seed(0) latent, adjacency_matrices = prepare_data() @@ -144,3 +156,10 @@ def test_sparse_multi_dgca(): latent_tf = latent_tf.eval(session = tf.Session()) assert np.all(latent_torch - latent_tf < .000001) + + +def test_graph_conv(): + latent_dense = graph_conv_torch() + latent_sparse = sparse_graph_conv_torch() + + assert np.all(latent_dense.detach().numpy() == latent_sparse.detach().numpy())