|  |  | @@ -133,9 +133,10 @@ class Model(torch.nn.Module): | 
		
	
		
			
			|  |  |  | List[torch.Tensor]: | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | cur_layer_repr = in_layer_repr | 
		
	
		
			
			|  |  |  | next_layer_repr = [ None ] * len(self.data.vertex_types) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for i in range(len(self.layer_dimensions) - 1): | 
		
	
		
			
			|  |  |  | next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for _, et in self.data.edge_types.items(): | 
		
	
		
			
			|  |  |  | vt_row, vt_col = et.vertex_type_row, et.vertex_type_column | 
		
	
		
			
			|  |  |  | adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)] | 
		
	
	
		
			
				|  |  | @@ -158,13 +159,18 @@ class Model(torch.nn.Module): | 
		
	
		
			
			|  |  |  | self.data.vertex_types[vt_row].count, | 
		
	
		
			
			|  |  |  | self.layer_dimensions[i + 1]) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | print('b, Layer:', i, 'x.shape:', x.shape) | 
		
	
		
			
			|  |  |  | print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | x = x.sum(dim=0) | 
		
	
		
			
			|  |  |  | x = torch.nn.functional.normalize(x, p=2, dim=1) | 
		
	
		
			
			|  |  |  | x = self.conv_activation(x) | 
		
	
		
			
			|  |  |  | # x = self.rel_activation(x) | 
		
	
		
			
			|  |  |  | print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | next_layer_repr[vt_row].append(x) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | next_layer_repr = [ self.conv_activation(sum(x)) \ | 
		
	
		
			
			|  |  |  | for x in next_layer_repr ] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | next_layer_repr[vt_row] = x | 
		
	
		
			
			|  |  |  | cur_layer_repr = next_layer_repr | 
		
	
		
			
			|  |  |  | return next_layer_repr | 
		
	
		
			
			|  |  |  |  | 
		
	
	
		
			
				|  |  | 
 |