|
|
@@ -31,7 +31,7 @@ def test_predictions_batch_01(): |
|
|
|
|
|
|
|
rel_pred = RelationPredictions(
|
|
|
|
TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
|
|
|
|
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
|
|
|
|
TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
|
|
|
|
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
|
|
|
|
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
|
|
|
|
)
|
|
|
@@ -40,7 +40,13 @@ def test_predictions_batch_01(): |
|
|
|
|
|
|
|
batch = PredictionsBatch(pred, part_type='train', batch_size=1)
|
|
|
|
count = 0
|
|
|
|
lst = []
|
|
|
|
for (input, target) in batch:
|
|
|
|
assert len(input) == 1
|
|
|
|
assert len(target) == 1
|
|
|
|
lst.append((input[0], target[0]))
|
|
|
|
count += 1
|
|
|
|
assert lst == [ (1, 1), (0, 1), (1, 1), (0, 1), (1, 1),
|
|
|
|
(1, 0), (0, 0), (1, 0), (0, 0), (1, 0) ]
|
|
|
|
|
|
|
|
assert count == 10
|