|  |  | @@ -301,3 +301,46 @@ class DenseMultiDGCA(torch.nn.Module): | 
		
	
		
			
			|  |  |  | out += f(x[i]) | 
		
	
		
			
			|  |  |  | out = torch.nn.functional.normalize(out, p=2, dim=1) | 
		
	
		
			
			|  |  |  | return out | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | class GraphConv(torch.nn.Module): | 
		
	
		
			
			|  |  |  | """Convolution layer for sparse AND dense inputs.""" | 
		
	
		
			
			|  |  |  | def __init__(self, in_channels: int, out_channels: int, | 
		
	
		
			
			|  |  |  | adjacency_matrix: torch.Tensor, **kwargs) -> None: | 
		
	
		
			
			|  |  |  | super().__init__(**kwargs) | 
		
	
		
			
			|  |  |  | self.in_channels = in_channels | 
		
	
		
			
			|  |  |  | self.out_channels = out_channels | 
		
	
		
			
			|  |  |  | self.weight = init_glorot(in_channels, out_channels) | 
		
	
		
			
			|  |  |  | self.adjacency_matrix = adjacency_matrix | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def forward(self, x: torch.Tensor) -> torch.Tensor: | 
		
	
		
			
			|  |  |  | x = torch.sparse.mm(x, self.weight) \ | 
		
	
		
			
			|  |  |  | if x.is_sparse \ | 
		
	
		
			
			|  |  |  | else torch.mm(x, self.weight) | 
		
	
		
			
			|  |  |  | x = torch.sparse.mm(self.adjacency_matrix, x) \ | 
		
	
		
			
			|  |  |  | if self.adjacency_matrix.is_sparse \ | 
		
	
		
			
			|  |  |  | else torch.mm(self.adjacency_matrix, x) | 
		
	
		
			
			|  |  |  | return x | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | class DropoutGraphConvActivation(torch.nn.Module): | 
		
	
		
			
			|  |  |  | def __init__(self, input_dim: int, output_dim: int, | 
		
	
		
			
			|  |  |  | adjacency_matrix: torch.Tensor, keep_prob: float=1., | 
		
	
		
			
			|  |  |  | activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | 
		
	
		
			
			|  |  |  | **kwargs) -> None: | 
		
	
		
			
			|  |  |  | super().__init__(**kwargs) | 
		
	
		
			
			|  |  |  | self.input_dim = input_dim | 
		
	
		
			
			|  |  |  | self.output_dim = output_dim | 
		
	
		
			
			|  |  |  | self.adjacency_matrix = adjacency_matrix | 
		
	
		
			
			|  |  |  | self.keep_prob = keep_prob | 
		
	
		
			
			|  |  |  | self.activation = activation | 
		
	
		
			
			|  |  |  | self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def forward(self, x: torch.Tensor) -> torch.Tensor: | 
		
	
		
			
			|  |  |  | x = dropout_sparse(x, self.keep_prob) \ | 
		
	
		
			
			|  |  |  | if x.is_sparse \ | 
		
	
		
			
			|  |  |  | else dropout(x, self.keep_prob) | 
		
	
		
			
			|  |  |  | x = self.graph_conv(x) | 
		
	
		
			
			|  |  |  | x = self.activation(x) | 
		
	
		
			
			|  |  |  | return x |