|  |  | @@ -5,6 +5,37 @@ from .data import Data | 
		
	
		
			
			|  |  |  | from .trainprep import PreparedData | 
		
	
		
			
			|  |  |  | import torch | 
		
	
		
			
			|  |  |  | from .weights import init_glorot | 
		
	
		
			
			|  |  |  | from .normalize import _sparse_coo_tensor | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def _sparse_diag_cat(matrices: List[torch.Tensor]): | 
		
	
		
			
			|  |  |  | if len(matrices) == 0: | 
		
	
		
			
			|  |  |  | raise ValueError('The list of matrices must be non-empty') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not all(m.is_sparse for m in matrices): | 
		
	
		
			
			|  |  |  | raise ValueError('All matrices must be sparse') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not all(len(m.shape) == 2 for m in matrices): | 
		
	
		
			
			|  |  |  | raise ValueError('All matrices must be 2D') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | indices = [] | 
		
	
		
			
			|  |  |  | values = [] | 
		
	
		
			
			|  |  |  | row_offset = 0 | 
		
	
		
			
			|  |  |  | col_offset = 0 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for m in matrices: | 
		
	
		
			
			|  |  |  | ind = m._indices().clone() | 
		
	
		
			
			|  |  |  | ind[0] += row_offset | 
		
	
		
			
			|  |  |  | ind[1] += col_offset | 
		
	
		
			
			|  |  |  | indices.append(ind) | 
		
	
		
			
			|  |  |  | values.append(m._values()) | 
		
	
		
			
			|  |  |  | row_offset += m.shape[0] | 
		
	
		
			
			|  |  |  | col_offset += m.shape[1] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | indices = torch.cat(indices, dim=1) | 
		
	
		
			
			|  |  |  | values = torch.cat(values) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def _cat(matrices: List[torch.Tensor]): | 
		
	
	
		
			
				|  |  | @@ -79,7 +110,7 @@ class FastConvLayer(torch.nn.Module): | 
		
	
		
			
			|  |  |  | output_dim: List[int], | 
		
	
		
			
			|  |  |  | data: Union[Data, PreparedData], | 
		
	
		
			
			|  |  |  | keep_prob: float = 1., | 
		
	
		
			
			|  |  |  | rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x | 
		
	
		
			
			|  |  |  | rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | 
		
	
		
			
			|  |  |  | layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, | 
		
	
		
			
			|  |  |  | **kwargs): | 
		
	
		
			
			|  |  |  |  | 
		
	
	
		
			
				|  |  | 
 |