diff --git a/.gitignore b/.gitignore index c20c2ab..837e16b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ __pycache__ - +.cache/ diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py index e69de29..d125e98 100644 --- a/src/decagon_pytorch/convolve.py +++ b/src/decagon_pytorch/convolve.py @@ -0,0 +1,53 @@ +import torch +from .dropout import dropout_sparse +from .weights import init_glorot + + +class SparseGraphConv(torch.nn.Module): + """Convolution layer for sparse inputs.""" + def __init__(self, in_channels, out_channels, + adjacency_matrix, **kwargs): + super().__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.weight = init_glorot(in_channels, out_channels) + self.adjacency_matrix = adjacency_matrix + + + def forward(self, x): + x = torch.sparse.mm(x, self.weight) + x = torch.sparse.mm(self.adjacency_matrix, x) + return x + + +class SparseDropoutGraphConvActivation(torch.nn.Module): + def __init__(self, input_dim, output_dim, + adjacency_matrix, keep_prob=1., + activation=torch.nn.functional.relu, + **kwargs): + super().__init__(**kwargs) + self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim) + self.keep_prob = keep_prob + self.activation = activation + + def forward(self, x): + x = dropout_sparse(x, self.keep_prob) + x = self.sparse_graph_conv(x) + x = self.activation(x) + return x + + +class SparseMultiDGCA(torch.nn.Module): + def __init__(self, input_dim, output_dim, + adjacency_matrices, keep_prob=1., + activation=torch.nn.functional.relu, + **kwargs): + super().__init__(**kwargs) + self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] + + def forward(self, x): + out = torch.zeros(len(x), output_dim, dtype=x.dtype) + for f in self.sparse_dgca: + out += f(x) + out = torch.nn.functional.normalize(out, p=2, dim=1) + return out diff --git a/src/decagon_pytorch/weights.py b/src/decagon_pytorch/weights.py index 305a70f..cfab885 100644 --- a/src/decagon_pytorch/weights.py +++ b/src/decagon_pytorch/weights.py @@ -2,12 +2,12 @@ import torch import numpy as np -def init_glorot(input_dim, output_dim): +def init_glorot(in_channels, out_channels, dtype=torch.float32): """Create a weight variable with Glorot & Bengio (AISTATS 2010) initialization. """ - init_range = np.sqrt(6.0 / (input_dim + output_dim)) + init_range = np.sqrt(6.0 / (in_channels + out_channels)) initial = -init_range + 2 * init_range * \ - torch.rand(( input_dim, output_dim ), dtype=torch.float32) + torch.rand(( in_channels, out_channels ), dtype=dtype) initial = initial.requires_grad_(True) return initial diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py new file mode 100644 index 0000000..611f0a1 --- /dev/null +++ b/tests/decagon_pytorch/test_convolve.py @@ -0,0 +1,69 @@ +import decagon_pytorch.convolve +import decagon.deep.layers +import torch +import tensorflow as tf +import numpy as np + + +def prepare_data(): + np.random.seed(0) + latent = np.random.random((5, 10)).astype(np.float32) + latent[latent < .5] = 0 + latent = np.ceil(latent) + adjacency_matrices = [] + for _ in range(5): + adj_mat = np.random.random((len(latent),) * 2).astype(np.float32) + adj_mat[adj_mat < .5] = 0 + adj_mat = np.ceil(adj_mat) + adjacency_matrices.append(adj_mat) + return latent, adjacency_matrices + + +def sparse_graph_conv_torch(): + torch.random.manual_seed(0) + latent, adjacency_matrices = prepare_data() + print('latent.dtype:', latent.dtype) + latent = torch.tensor(latent).to_sparse() + adj_mat = adjacency_matrices[0] + adj_mat = torch.tensor(adj_mat).to_sparse() + print('adj_mat.dtype:', adj_mat.dtype, + 'latent.dtype:', latent.dtype) + conv = decagon_pytorch.convolve.SparseGraphConv(10, 10, + adj_mat) + latent = conv(latent) + return latent + + +def dense_to_sparse_tf(x): + a, b = np.where(x) + indices = np.array([a, b]).T + values = x[a, b] + return tf.sparse.SparseTensor(indices, values, x.shape) + + + +def sparse_graph_conv_tf(): + torch.random.manual_seed(0) + latent, adjacency_matrices = prepare_data() + conv_torch = decagon_pytorch.convolve.SparseGraphConv(10, 10, + torch.tensor(adjacency_matrices[0]).to_sparse()) + weight = tf.constant(conv_torch.weight.detach().numpy()) + latent = dense_to_sparse_tf(latent) + adj_mat = dense_to_sparse_tf(adjacency_matrices[0]) + latent = tf.sparse_tensor_dense_matmul(latent, weight) + latent = tf.sparse_tensor_dense_matmul(adj_mat, latent) + return latent + + +def test_sparse_graph_conv(): + latent_torch = sparse_graph_conv_torch() + latent_tf = sparse_graph_conv_tf() + assert np.all(latent_torch.detach().numpy() == latent_tf.eval(session = tf.Session())) + + +def test_sparse_dropout_grap_conv_activation(): + pass + + +def test_sparse_multi_dgca(): + pass