@@ -23,7 +23,9 @@
 import torch
 from .convolve import SparseMultiDGCA
 from .data import Data
-from typing import List, Union
+from typing import List, \
+    Union, \
+    Callable
 
 
 class Layer(torch.nn.Module):
@@ -65,15 +67,20 @@ class InputLayer(Layer):
 
 
 class DecagonLayer(Layer):
-    def __init__(self, data: Data,
-        input_dim, output_dim,
-        keep_prob=1.,
-        rel_activation=lambda x: x,
-        layer_activation=torch.nn.functional.relu,
+    def __init__(self,
+        data: Data,
+        previous_layer: Layer,
+        output_dim: Union[int, List[int]],
+        keep_prob: float = 1.,
+        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
         **kwargs):
+        if not isinstance(output_dim, list):
+            output_dim = [ output_dim ] * len(data.node_types)
         super().__init__(output_dim, **kwargs)
         self.data = data
-        self.input_dim = input_dim
+        self.previous_layer = previous_layer
+        self.input_dim = previous_layer.output_dim
         self.keep_prob = keep_prob
         self.rel_activation = rel_activation
         self.layer_activation = layer_activation
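The net effect of the second hunk: DecagonLayer no longer takes input_dim explicitly; it receives the previous layer and reads input_dim from previous_layer.output_dim, and a scalar output_dim is expanded to one entry per node type. Below is a minimal, self-contained sketch of that chaining pattern; StubLayer, ChainedLayer, and num_node_types are hypothetical stand-ins for the repository's Layer, DecagonLayer, and len(data.node_types), not its actual API.

from typing import List, Union

import torch


class StubLayer(torch.nn.Module):
    # Stand-in for the repository's Layer base class (hypothetical name).
    def __init__(self, output_dim: List[int], **kwargs) -> None:
        super().__init__(**kwargs)
        self.output_dim = output_dim


class ChainedLayer(StubLayer):
    # Mirrors the constructor pattern introduced by this diff:
    # input_dim is derived from the previous layer instead of being
    # passed in as a separate argument.
    def __init__(self, previous_layer: StubLayer,
        output_dim: Union[int, List[int]],
        num_node_types: int = 2,
        **kwargs) -> None:

        # Normalize a scalar output_dim to one entry per node type,
        # as the diff does with len(data.node_types).
        if not isinstance(output_dim, list):
            output_dim = [ output_dim ] * num_node_types
        super().__init__(output_dim, **kwargs)
        self.previous_layer = previous_layer
        self.input_dim = previous_layer.output_dim


first = StubLayer(output_dim=[64, 64])
second = ChainedLayer(first, output_dim=32)
assert second.input_dim == [64, 64]
assert second.output_dim == [32, 32]

One consequence of this design is that layers can only be constructed in topological order, since each constructor dereferences previous_layer.output_dim immediately; in exchange, mismatched input/output dimensions between adjacent layers become impossible to express.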