| @@ -92,6 +92,27 @@ def test_flatten_predictions_04(): | |||||
| pred_flat = flatten_predictions(pred, part_type='x') | pred_flat = flatten_predictions(pred, part_type='x') | ||||
| def test_flatten_predictions_05(): | |||||
| x = torch.rand(5000) | |||||
| y = torch.cat([ x, x ]) | |||||
| z = torch.cat([ torch.ones(5000), torch.zeros(5000) ]) | |||||
| rel_pred = RelationPredictions( | |||||
| TrainValTest(x, torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(x, torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||||
| ) | |||||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||||
| pred = Predictions([ fam_pred ]) | |||||
| for _ in range(10): | |||||
| pred_flat = flatten_predictions(pred, part_type='train') | |||||
| assert torch.all(pred_flat.predictions == y) | |||||
| assert torch.all(pred_flat.truth == z) | |||||
| assert pred_flat.part_type == 'train' | |||||
| def test_batch_indices_01(): | def test_batch_indices_01(): | ||||
| indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train') | indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train') | ||||
| assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4])) | assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4])) | ||||