|  |  | @@ -232,3 +232,56 @@ def test_decode_layer_05(): | 
		
	
		
			
			|  |  |  | # assert isinstance(rel_pred.edges_neg, TrainValTest) | 
		
	
		
			
			|  |  |  | # assert isinstance(rel_pred.edges_back_pos, TrainValTest) | 
		
	
		
			
			|  |  |  | # assert isinstance(rel_pred.edges_back_neg, TrainValTest) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_decode_layer_parameter_count_01(): | 
		
	
		
			
			|  |  |  | d = Data() | 
		
	
		
			
			|  |  |  | d.add_node_type('Dummy', 100) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Dummy Relation 1', | 
		
	
		
			
			|  |  |  | torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1., | 
		
	
		
			
			|  |  |  | activation=lambda x: x) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert len(list(dec.parameters())) == 2 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_decode_layer_parameter_count_02(): | 
		
	
		
			
			|  |  |  | d = Data() | 
		
	
		
			
			|  |  |  | d.add_node_type('Dummy', 100) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Dummy Relation 1', | 
		
	
		
			
			|  |  |  | torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Dummy Relation 2', | 
		
	
		
			
			|  |  |  | torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1., | 
		
	
		
			
			|  |  |  | activation=lambda x: x) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert len(list(dec.parameters())) == 3 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_decode_layer_parameter_count_03(): | 
		
	
		
			
			|  |  |  | d = Data() | 
		
	
		
			
			|  |  |  | d.add_node_type('Dummy', 100) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for _ in range(2): | 
		
	
		
			
			|  |  |  | fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Dummy Relation 1', | 
		
	
		
			
			|  |  |  | torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Dummy Relation 2', | 
		
	
		
			
			|  |  |  | torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1., | 
		
	
		
			
			|  |  |  | activation=lambda x: x) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert len(list(dec.parameters())) == 6 |