@@ -1,2 +1,2 @@ | |||||
__pycache__ | __pycache__ | ||||
.cache/ |
@@ -0,0 +1,53 @@ | |||||
import torch | |||||
from .dropout import dropout_sparse | |||||
from .weights import init_glorot | |||||
class SparseGraphConv(torch.nn.Module): | |||||
"""Convolution layer for sparse inputs.""" | |||||
def __init__(self, in_channels, out_channels, | |||||
adjacency_matrix, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.in_channels = in_channels | |||||
self.out_channels = out_channels | |||||
self.weight = init_glorot(in_channels, out_channels) | |||||
self.adjacency_matrix = adjacency_matrix | |||||
def forward(self, x): | |||||
x = torch.sparse.mm(x, self.weight) | |||||
x = torch.sparse.mm(self.adjacency_matrix, x) | |||||
return x | |||||
class SparseDropoutGraphConvActivation(torch.nn.Module): | |||||
def __init__(self, input_dim, output_dim, | |||||
adjacency_matrix, keep_prob=1., | |||||
activation=torch.nn.functional.relu, | |||||
**kwargs): | |||||
super().__init__(**kwargs) | |||||
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim) | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
def forward(self, x): | |||||
x = dropout_sparse(x, self.keep_prob) | |||||
x = self.sparse_graph_conv(x) | |||||
x = self.activation(x) | |||||
return x | |||||
class SparseMultiDGCA(torch.nn.Module): | |||||
def __init__(self, input_dim, output_dim, | |||||
adjacency_matrices, keep_prob=1., | |||||
activation=torch.nn.functional.relu, | |||||
**kwargs): | |||||
super().__init__(**kwargs) | |||||
self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] | |||||
def forward(self, x): | |||||
out = torch.zeros(len(x), output_dim, dtype=x.dtype) | |||||
for f in self.sparse_dgca: | |||||
out += f(x) | |||||
out = torch.nn.functional.normalize(out, p=2, dim=1) | |||||
return out |
@@ -2,12 +2,12 @@ import torch | |||||
import numpy as np | import numpy as np | ||||
def init_glorot(input_dim, output_dim): | |||||
def init_glorot(in_channels, out_channels, dtype=torch.float32): | |||||
"""Create a weight variable with Glorot & Bengio (AISTATS 2010) | """Create a weight variable with Glorot & Bengio (AISTATS 2010) | ||||
initialization. | initialization. | ||||
""" | """ | ||||
init_range = np.sqrt(6.0 / (input_dim + output_dim)) | |||||
init_range = np.sqrt(6.0 / (in_channels + out_channels)) | |||||
initial = -init_range + 2 * init_range * \ | initial = -init_range + 2 * init_range * \ | ||||
torch.rand(( input_dim, output_dim ), dtype=torch.float32) | |||||
torch.rand(( in_channels, out_channels ), dtype=dtype) | |||||
initial = initial.requires_grad_(True) | initial = initial.requires_grad_(True) | ||||
return initial | return initial |
@@ -0,0 +1,69 @@ | |||||
import decagon_pytorch.convolve | |||||
import decagon.deep.layers | |||||
import torch | |||||
import tensorflow as tf | |||||
import numpy as np | |||||
def prepare_data(): | |||||
np.random.seed(0) | |||||
latent = np.random.random((5, 10)).astype(np.float32) | |||||
latent[latent < .5] = 0 | |||||
latent = np.ceil(latent) | |||||
adjacency_matrices = [] | |||||
for _ in range(5): | |||||
adj_mat = np.random.random((len(latent),) * 2).astype(np.float32) | |||||
adj_mat[adj_mat < .5] = 0 | |||||
adj_mat = np.ceil(adj_mat) | |||||
adjacency_matrices.append(adj_mat) | |||||
return latent, adjacency_matrices | |||||
def sparse_graph_conv_torch(): | |||||
torch.random.manual_seed(0) | |||||
latent, adjacency_matrices = prepare_data() | |||||
print('latent.dtype:', latent.dtype) | |||||
latent = torch.tensor(latent).to_sparse() | |||||
adj_mat = adjacency_matrices[0] | |||||
adj_mat = torch.tensor(adj_mat).to_sparse() | |||||
print('adj_mat.dtype:', adj_mat.dtype, | |||||
'latent.dtype:', latent.dtype) | |||||
conv = decagon_pytorch.convolve.SparseGraphConv(10, 10, | |||||
adj_mat) | |||||
latent = conv(latent) | |||||
return latent | |||||
def dense_to_sparse_tf(x): | |||||
a, b = np.where(x) | |||||
indices = np.array([a, b]).T | |||||
values = x[a, b] | |||||
return tf.sparse.SparseTensor(indices, values, x.shape) | |||||
def sparse_graph_conv_tf(): | |||||
torch.random.manual_seed(0) | |||||
latent, adjacency_matrices = prepare_data() | |||||
conv_torch = decagon_pytorch.convolve.SparseGraphConv(10, 10, | |||||
torch.tensor(adjacency_matrices[0]).to_sparse()) | |||||
weight = tf.constant(conv_torch.weight.detach().numpy()) | |||||
latent = dense_to_sparse_tf(latent) | |||||
adj_mat = dense_to_sparse_tf(adjacency_matrices[0]) | |||||
latent = tf.sparse_tensor_dense_matmul(latent, weight) | |||||
latent = tf.sparse_tensor_dense_matmul(adj_mat, latent) | |||||
return latent | |||||
def test_sparse_graph_conv(): | |||||
latent_torch = sparse_graph_conv_torch() | |||||
latent_tf = sparse_graph_conv_tf() | |||||
assert np.all(latent_torch.detach().numpy() == latent_tf.eval(session = tf.Session())) | |||||
def test_sparse_dropout_grap_conv_activation(): | |||||
pass | |||||
def test_sparse_multi_dgca(): | |||||
pass |