|  |  | @@ -22,6 +22,12 @@ def test_cross_entropy_loss_01(): | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | prep_d = prepare_training(d, TrainValTest(1., 0., 0.)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert len(prep_d.relation_families) == 1 | 
		
	
		
			
			|  |  |  | assert len(prep_d.relation_families[0].relation_types) == 1 | 
		
	
		
			
			|  |  |  | assert len(prep_d.relation_families[0].relation_types[0].edges_pos.train) == 5 | 
		
	
		
			
			|  |  |  | assert len(prep_d.relation_families[0].relation_types[0].edges_pos.val) == 0 | 
		
	
		
			
			|  |  |  | assert len(prep_d.relation_families[0].relation_types[0].edges_pos.test) == 0 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | rel_pred = RelationPredictions( | 
		
	
		
			
			|  |  |  | TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | 
		
	
		
			
			|  |  |  | TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | 
		
	
	
		
			
				|  |  | 
 |