|  |  | @@ -60,11 +60,11 @@ class OneHotInputLayer(torch.nn.Module): | 
		
	
		
			
			|  |  |  | self.build() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def build(self) -> None: | 
		
	
		
			
			|  |  |  | self.node_reps = [] | 
		
	
		
			
			|  |  |  | self.node_reps = torch.nn.ParameterList() | 
		
	
		
			
			|  |  |  | for i, nt in enumerate(self.data.node_types): | 
		
	
		
			
			|  |  |  | reps = torch.eye(nt.count).to_sparse() | 
		
	
		
			
			|  |  |  | reps = torch.nn.Parameter(reps) | 
		
	
		
			
			|  |  |  | self.register_parameter('node_reps[%d]' % i, reps) | 
		
	
		
			
			|  |  |  | reps = torch.nn.Parameter(reps, requires_grad=False) | 
		
	
		
			
			|  |  |  | # self.register_parameter('node_reps[%d]' % i, reps) | 
		
	
		
			
			|  |  |  | self.node_reps.append(reps) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def forward(self, x) -> List[torch.nn.Parameter]: | 
		
	
	
		
			
				|  |  | 
 |