@@ -1,104 +1,272 @@
			import torch
from .dropout import dropout_sparse, \
    dropout
from .weights import init_glorot


class SparseGraphConv(torch.nn.Module):
    """Convolution layer for sparse inputs."""
    def __init__(self, in_channels, out_channels,
        adjacency_matrix, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x):
        x = torch.sparse.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x)
        return x


class SparseDropoutGraphConvActivation(torch.nn.Module):
    def __init__(self, input_dim, output_dim,
        adjacency_matrix, keep_prob=1.,
        activation=torch.nn.functional.relu,
        **kwargs):
        super().__init__(**kwargs)
        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x):
        x = dropout_sparse(x, self.keep_prob)
        x = self.sparse_graph_conv(x)
        x = self.activation(x)
        return x


class SparseMultiDGCA(torch.nn.Module):
    def __init__(self, input_dim, output_dim,
        adjacency_matrices, keep_prob=1.,
        activation=torch.nn.functional.relu,
        **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.sparse_dgca = [ SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]

    def forward(self, x):
        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
        for f in self.sparse_dgca:
            out += f(x)
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out


class GraphConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels,
        adjacency_matrix, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x):
        x = torch.mm(x, self.weight)
        x = torch.mm(self.adjacency_matrix, x)
        return x


class DropoutGraphConvActivation(torch.nn.Module):
    def __init__(self, input_dim, output_dim,
        adjacency_matrix, keep_prob=1.,
        activation=torch.nn.functional.relu,
        **kwargs):
        super().__init__(**kwargs)
        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x):
        x = dropout(x, keep_prob=self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x


class MultiDGCA(torch.nn.Module):
    def __init__(self, input_dim, output_dim,
        adjacency_matrices, keep_prob=1.,
        activation=torch.nn.functional.relu,
        **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]

    def forward(self, x):
        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
        for f in self.dgca:
            out += f(x)
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out


#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#


"""
This module implements the basic convolutional blocks of Decagon.
Just as a quick reminder, the basic convolution formula here is:

y = A * (x * W)

where:

W is a weight matrix
A is an adjacency matrix
x is a matrix of latent representations of a particular type of neighbors.

As we have x here twice, a trick is obviously necessary for this to work.
A must be previously normalized with:

c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)

or

c_{r}^{i} = 1/|N_{r}^{i}|

Let's work through this step by step to convince ourselves that the
formula is correct.

x = [
    [0, 1, 0, 1],
    [1, 1, 1, 0],
    [0, 0, 0, 1]
]

W = [
    [0, 1],
    [1, 0],
    [0.5, 0.5],
    [0.25, 0.75]
]

A = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0]
]

so the graph looks like this:

(0) -- (1) -- (2)

and therefore the representations in the next layer should be:

h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W +
    c_{r}^{1} * h_{1}^{k}
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}

In the actual Decagon code we can see that the latter part, propagating the
old representation directly, is gone. I will try to do the same for now.

So we only have to take care of:

h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W

If A is square, Decagon's EdgeMinibatchIterator preprocesses it as follows:

A = A + eye(len(A))
rowsum = A.sum(1)
deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
A = dot(A, deg_mat_inv_sqrt)
A = A.transpose()
A = A.dot(deg_mat_inv_sqrt)

Let's see what this gives in our case:

A = A + eye(len(A))

[
    [1, 1, 0],
    [1, 1, 1],
    [0, 1, 1]
]

rowsum = A.sum(1)

[2, 3, 2]

deg_mat_inv_sqrt = diags(power(rowsum, -0.5))

[
    [1./sqrt(2), 0,  0],
    [0, 1./sqrt(3),  0],
    [0,  0, 1./sqrt(2)]
]

A = dot(A, deg_mat_inv_sqrt)

[
    [ 1/sqrt(2), 1/sqrt(3),         0 ],
    [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
    [         0, 1/sqrt(3), 1/sqrt(2) ]
]

A = A.transpose()

[
    [ 1/sqrt(2), 1/sqrt(2),         0 ],
    [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
    [         0, 1/sqrt(2), 1/sqrt(2) ]
]

A = A.dot(deg_mat_inv_sqrt)

[
    [ 1/sqrt(2) * 1/sqrt(2),   1/sqrt(2) * 1/sqrt(3),                       0 ],
    [ 1/sqrt(3) * 1/sqrt(2),   1/sqrt(3) * 1/sqrt(3),   1/sqrt(3) * 1/sqrt(2) ],
    [                     0,   1/sqrt(2) * 1/sqrt(3),   1/sqrt(2) * 1/sqrt(2) ],
]

thus:

[
    [0.5       , 0.40824829, 0.        ],
    [0.40824829, 0.33333333, 0.40824829],
    [0.        , 0.40824829, 0.5       ]
]

This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula.

Then, we get back to the main calculation:

y = x * W
y = A * y

y = x * W

[
    [ 1.25, 0.75 ],
    [ 1.5 , 1.5  ],
    [ 0.25, 0.75 ]
]

y = A * y

[
    0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
    0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ],
    0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
]

that is:

[
    [1.23737243, 0.98737244],
    [1.11237243, 1.11237243],
    [0.73737244, 0.98737244]
].

All checks out nicely, good.
"""


import torch
from .dropout import dropout_sparse, \
    dropout
from .weights import init_glorot
from typing import List, Callable


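# Illustrative sanity check (a hypothetical helper, not used by the classes
# below): it reproduces the normalization walk-through from the module
# docstring, i.e. D^-1/2 * (A + I) * D^-1/2 for the path graph (0) -- (1) -- (2).
def _check_docstring_normalization() -> torch.Tensor:
    A = torch.eye(3) + torch.tensor([
        [0., 1., 0.],
        [1., 0., 1.],
        [0., 1., 0.]
    ])
    deg_mat_inv_sqrt = torch.diag(A.sum(1).pow(-0.5))
    # same order of operations as the EdgeMinibatchIterator recipe quoted
    # in the docstring: dot(A, D^-1/2), transpose, dot(., D^-1/2)
    A = (A @ deg_mat_inv_sqrt).t() @ deg_mat_inv_sqrt
    # expected, up to float precision:
    # [ [0.5,        0.40824829, 0.        ],
    #   [0.40824829, 0.33333333, 0.40824829],
    #   [0.,         0.40824829, 0.5       ] ]
    return A

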
class SparseGraphConv(torch.nn.Module):
    """Convolution layer for sparse inputs."""
    def __init__(self, in_channels: int, out_channels: int,
        adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.sparse.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x)
        return x


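# Hypothetical usage sketch (the helper name and the inputs are illustrative,
# not part of the original API): the docstring's feature matrix is fed in as a
# sparse COO tensor. init_glorot draws random weights, so only the output
# shape is deterministic here.
def _sparse_graph_conv_example() -> torch.Tensor:
    x = torch.tensor([
        [0., 1., 0., 1.],
        [1., 1., 1., 0.],
        [0., 0., 0., 1.]
    ]).to_sparse()
    adjacency_matrix = _check_docstring_normalization().to_sparse()
    conv = SparseGraphConv(4, 2, adjacency_matrix)
    return conv(x)  # dense tensor of shape (3, 2)

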
class SparseDropoutGraphConvActivation(torch.nn.Module):
    def __init__(self, input_dim: int, output_dim: int,
        adjacency_matrix: torch.Tensor, keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout_sparse(x, self.keep_prob)
        x = self.sparse_graph_conv(x)
        x = self.activation(x)
        return x


class SparseMultiDGCA(torch.nn.Module):
    def __init__(self, input_dim: int, output_dim: int,
        adjacency_matrices: List[torch.Tensor], keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.output_dim = output_dim
        # ModuleList rather than a plain list, so that the per-relation
        # sub-layers are registered with the parent module
        self.sparse_dgca = torch.nn.ModuleList([
            SparseDropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation)
            for adj_mat in adjacency_matrices
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
        for f in self.sparse_dgca:
            out += f(x)
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out


class GraphConv(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int,
        adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.mm(x, self.weight)
        x = torch.mm(self.adjacency_matrix, x)
        return x


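# Dense counterpart of the docstring walk-through (illustrative only): with the
# docstring's fixed W substituted for the random init_glorot weights,
# y = A_norm * (x * W) should reproduce the hand-computed result
# [[1.23737243, 0.98737244], [1.11237243, 1.11237243], [0.73737244, 0.98737244]].
def _check_docstring_forward() -> torch.Tensor:
    x = torch.tensor([
        [0., 1., 0., 1.],
        [1., 1., 1., 0.],
        [0., 0., 0., 1.]
    ])
    W = torch.tensor([
        [0., 1.],
        [1., 0.],
        [0.5, 0.5],
        [0.25, 0.75]
    ])
    conv = GraphConv(4, 2, _check_docstring_normalization())
    conv.weight = W  # override the random weights for a deterministic check
    return conv(x)

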
class DropoutGraphConvActivation(torch.nn.Module):
    def __init__(self, input_dim: int, output_dim: int,
        adjacency_matrix: torch.Tensor, keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout(x, keep_prob=self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x


class MultiDGCA(torch.nn.Module):
    def __init__(self, input_dim: int, output_dim: int,
        adjacency_matrices: List[torch.Tensor], keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.output_dim = output_dim
        # ModuleList rather than a plain list, so that the per-relation
        # sub-layers are registered with the parent module
        self.dgca = torch.nn.ModuleList([
            DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation)
            for adj_mat in adjacency_matrices
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
        for f in self.dgca:
            out += f(x)
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out


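# Hypothetical end-to-end sketch (names and inputs below are illustrative):
# combining two relation types with MultiDGCA. Each adjacency matrix gets its
# own DropoutGraphConvActivation; the per-relation outputs are summed and then
# L2-normalized row-wise, as in MultiDGCA.forward() above.
def _multi_dgca_example() -> torch.Tensor:
    x = torch.rand(3, 4)
    adj_1 = _check_docstring_normalization()
    adj_2 = torch.eye(3)  # a second, trivial relation type for illustration
    multi = MultiDGCA(4, 2, [adj_1, adj_2], keep_prob=0.5)
    return multi(x)  # tensor of shape (3, 2), rows L2-normalized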