|
|
@@ -92,6 +92,27 @@ def test_flatten_predictions_04(): |
|
|
|
pred_flat = flatten_predictions(pred, part_type='x')
|
|
|
|
|
|
|
|
|
|
|
|
def test_flatten_predictions_05():
|
|
|
|
x = torch.rand(5000)
|
|
|
|
y = torch.cat([ x, x ])
|
|
|
|
z = torch.cat([ torch.ones(5000), torch.zeros(5000) ])
|
|
|
|
|
|
|
|
rel_pred = RelationPredictions(
|
|
|
|
TrainValTest(x, torch.zeros(0), torch.zeros(0)),
|
|
|
|
TrainValTest(x, torch.zeros(0), torch.zeros(0)),
|
|
|
|
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
|
|
|
|
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
|
|
|
|
)
|
|
|
|
fam_pred = RelationFamilyPredictions([ rel_pred ])
|
|
|
|
pred = Predictions([ fam_pred ])
|
|
|
|
|
|
|
|
for _ in range(10):
|
|
|
|
pred_flat = flatten_predictions(pred, part_type='train')
|
|
|
|
assert torch.all(pred_flat.predictions == y)
|
|
|
|
assert torch.all(pred_flat.truth == z)
|
|
|
|
assert pred_flat.part_type == 'train'
|
|
|
|
|
|
|
|
|
|
|
|
def test_batch_indices_01():
|
|
|
|
indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train')
|
|
|
|
assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4]))
|
|
|
|