|  |  | @@ -92,6 +92,27 @@ def test_flatten_predictions_04(): | 
		
	
		
			
			|  |  |  | pred_flat = flatten_predictions(pred, part_type='x') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_flatten_predictions_05(): | 
		
	
		
			
			|  |  |  | x = torch.rand(5000) | 
		
	
		
			
			|  |  |  | y = torch.cat([ x, x ]) | 
		
	
		
			
			|  |  |  | z = torch.cat([ torch.ones(5000), torch.zeros(5000) ]) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | rel_pred = RelationPredictions( | 
		
	
		
			
			|  |  |  | TrainValTest(x, torch.zeros(0), torch.zeros(0)), | 
		
	
		
			
			|  |  |  | TrainValTest(x, torch.zeros(0), torch.zeros(0)), | 
		
	
		
			
			|  |  |  | TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | 
		
	
		
			
			|  |  |  | TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | fam_pred = RelationFamilyPredictions([ rel_pred ]) | 
		
	
		
			
			|  |  |  | pred = Predictions([ fam_pred ]) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for _ in range(10): | 
		
	
		
			
			|  |  |  | pred_flat = flatten_predictions(pred, part_type='train') | 
		
	
		
			
			|  |  |  | assert torch.all(pred_flat.predictions == y) | 
		
	
		
			
			|  |  |  | assert torch.all(pred_flat.truth == z) | 
		
	
		
			
			|  |  |  | assert pred_flat.part_type == 'train' | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_batch_indices_01(): | 
		
	
		
			
			|  |  |  | indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train') | 
		
	
		
			
			|  |  |  | assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4])) | 
		
	
	
		
			
				|  |  | 
 |