@@ -1,2 +1,2 @@ | |||
__pycache__ | |||
.cache/ |
@@ -0,0 +1,53 @@ | |||
import torch | |||
from .dropout import dropout_sparse | |||
from .weights import init_glorot | |||
class SparseGraphConv(torch.nn.Module): | |||
"""Convolution layer for sparse inputs.""" | |||
def __init__(self, in_channels, out_channels, | |||
adjacency_matrix, **kwargs): | |||
super().__init__(**kwargs) | |||
self.in_channels = in_channels | |||
self.out_channels = out_channels | |||
self.weight = init_glorot(in_channels, out_channels) | |||
self.adjacency_matrix = adjacency_matrix | |||
def forward(self, x): | |||
x = torch.sparse.mm(x, self.weight) | |||
x = torch.sparse.mm(self.adjacency_matrix, x) | |||
return x | |||
class SparseDropoutGraphConvActivation(torch.nn.Module): | |||
def __init__(self, input_dim, output_dim, | |||
adjacency_matrix, keep_prob=1., | |||
activation=torch.nn.functional.relu, | |||
**kwargs): | |||
super().__init__(**kwargs) | |||
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim) | |||
self.keep_prob = keep_prob | |||
self.activation = activation | |||
def forward(self, x): | |||
x = dropout_sparse(x, self.keep_prob) | |||
x = self.sparse_graph_conv(x) | |||
x = self.activation(x) | |||
return x | |||
class SparseMultiDGCA(torch.nn.Module): | |||
def __init__(self, input_dim, output_dim, | |||
adjacency_matrices, keep_prob=1., | |||
activation=torch.nn.functional.relu, | |||
**kwargs): | |||
super().__init__(**kwargs) | |||
self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] | |||
def forward(self, x): | |||
out = torch.zeros(len(x), output_dim, dtype=x.dtype) | |||
for f in self.sparse_dgca: | |||
out += f(x) | |||
out = torch.nn.functional.normalize(out, p=2, dim=1) | |||
return out |
@@ -2,12 +2,12 @@ import torch | |||
import numpy as np | |||
def init_glorot(input_dim, output_dim): | |||
def init_glorot(in_channels, out_channels, dtype=torch.float32): | |||
"""Create a weight variable with Glorot & Bengio (AISTATS 2010) | |||
initialization. | |||
""" | |||
init_range = np.sqrt(6.0 / (input_dim + output_dim)) | |||
init_range = np.sqrt(6.0 / (in_channels + out_channels)) | |||
initial = -init_range + 2 * init_range * \ | |||
torch.rand(( input_dim, output_dim ), dtype=torch.float32) | |||
torch.rand(( in_channels, out_channels ), dtype=dtype) | |||
initial = initial.requires_grad_(True) | |||
return initial |
@@ -0,0 +1,69 @@ | |||
import decagon_pytorch.convolve | |||
import decagon.deep.layers | |||
import torch | |||
import tensorflow as tf | |||
import numpy as np | |||
def prepare_data(): | |||
np.random.seed(0) | |||
latent = np.random.random((5, 10)).astype(np.float32) | |||
latent[latent < .5] = 0 | |||
latent = np.ceil(latent) | |||
adjacency_matrices = [] | |||
for _ in range(5): | |||
adj_mat = np.random.random((len(latent),) * 2).astype(np.float32) | |||
adj_mat[adj_mat < .5] = 0 | |||
adj_mat = np.ceil(adj_mat) | |||
adjacency_matrices.append(adj_mat) | |||
return latent, adjacency_matrices | |||
def sparse_graph_conv_torch(): | |||
torch.random.manual_seed(0) | |||
latent, adjacency_matrices = prepare_data() | |||
print('latent.dtype:', latent.dtype) | |||
latent = torch.tensor(latent).to_sparse() | |||
adj_mat = adjacency_matrices[0] | |||
adj_mat = torch.tensor(adj_mat).to_sparse() | |||
print('adj_mat.dtype:', adj_mat.dtype, | |||
'latent.dtype:', latent.dtype) | |||
conv = decagon_pytorch.convolve.SparseGraphConv(10, 10, | |||
adj_mat) | |||
latent = conv(latent) | |||
return latent | |||
def dense_to_sparse_tf(x): | |||
a, b = np.where(x) | |||
indices = np.array([a, b]).T | |||
values = x[a, b] | |||
return tf.sparse.SparseTensor(indices, values, x.shape) | |||
def sparse_graph_conv_tf(): | |||
torch.random.manual_seed(0) | |||
latent, adjacency_matrices = prepare_data() | |||
conv_torch = decagon_pytorch.convolve.SparseGraphConv(10, 10, | |||
torch.tensor(adjacency_matrices[0]).to_sparse()) | |||
weight = tf.constant(conv_torch.weight.detach().numpy()) | |||
latent = dense_to_sparse_tf(latent) | |||
adj_mat = dense_to_sparse_tf(adjacency_matrices[0]) | |||
latent = tf.sparse_tensor_dense_matmul(latent, weight) | |||
latent = tf.sparse_tensor_dense_matmul(adj_mat, latent) | |||
return latent | |||
def test_sparse_graph_conv(): | |||
latent_torch = sparse_graph_conv_torch() | |||
latent_tf = sparse_graph_conv_tf() | |||
assert np.all(latent_torch.detach().numpy() == latent_tf.eval(session = tf.Session())) | |||
def test_sparse_dropout_grap_conv_activation(): | |||
pass | |||
def test_sparse_multi_dgca(): | |||
pass |