|  |  | @@ -210,3 +210,31 @@ def test_empty_inner_product_decoder_01(): | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for i in range(len(res)): | 
		
	
		
			
			|  |  |  | assert res[i].shape == (0,) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_dedicom_decoder_parameter_count_01(): | 
		
	
		
			
			|  |  |  | dec = DEDICOMDecoder(32, 7, keep_prob=1., | 
		
	
		
			
			|  |  |  | activation=torch.sigmoid) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert len(list(dec.parameters())) == 8 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_dist_mult_decoder_parameter_count_01(): | 
		
	
		
			
			|  |  |  | dec = DistMultDecoder(32, 7, keep_prob=1., | 
		
	
		
			
			|  |  |  | activation=torch.sigmoid) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert len(list(dec.parameters())) == 7 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_bilinear_decoder_parameter_count_01(): | 
		
	
		
			
			|  |  |  | dec = BilinearDecoder(32, 7, keep_prob=1., | 
		
	
		
			
			|  |  |  | activation=torch.sigmoid) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert len(list(dec.parameters())) == 7 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_inner_product_decoder_parameter_count_01(): | 
		
	
		
			
			|  |  |  | dec = InnerProductDecoder(32, 7, keep_prob=1., | 
		
	
		
			
			|  |  |  | activation=torch.sigmoid) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert len(list(dec.parameters())) == 0 |