diff --git a/tests/icosagon/test_batch.py b/tests/icosagon/test_batch.py index b6cd6d6..b5882db 100644 --- a/tests/icosagon/test_batch.py +++ b/tests/icosagon/test_batch.py @@ -92,6 +92,27 @@ def test_flatten_predictions_04(): pred_flat = flatten_predictions(pred, part_type='x') +def test_flatten_predictions_05(): + x = torch.rand(5000) + y = torch.cat([ x, x ]) + z = torch.cat([ torch.ones(5000), torch.zeros(5000) ]) + + rel_pred = RelationPredictions( + TrainValTest(x, torch.zeros(0), torch.zeros(0)), + TrainValTest(x, torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) + ) + fam_pred = RelationFamilyPredictions([ rel_pred ]) + pred = Predictions([ fam_pred ]) + + for _ in range(10): + pred_flat = flatten_predictions(pred, part_type='train') + assert torch.all(pred_flat.predictions == y) + assert torch.all(pred_flat.truth == z) + assert pred_flat.part_type == 'train' + + def test_batch_indices_01(): indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train') assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4]))