IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Parcourir la source

Add test_flatten_predictions_05().

master
Stanislaw Adaszewski il y a 4 ans
Parent
révision
a6de8c7846
1 fichiers modifiés avec 21 ajouts et 0 suppressions
  1. +21
    -0
      tests/icosagon/test_batch.py

+ 21
- 0
tests/icosagon/test_batch.py Voir le fichier

@@ -92,6 +92,27 @@ def test_flatten_predictions_04():
pred_flat = flatten_predictions(pred, part_type='x')
def test_flatten_predictions_05():
x = torch.rand(5000)
y = torch.cat([ x, x ])
z = torch.cat([ torch.ones(5000), torch.zeros(5000) ])
rel_pred = RelationPredictions(
TrainValTest(x, torch.zeros(0), torch.zeros(0)),
TrainValTest(x, torch.zeros(0), torch.zeros(0)),
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
)
fam_pred = RelationFamilyPredictions([ rel_pred ])
pred = Predictions([ fam_pred ])
for _ in range(10):
pred_flat = flatten_predictions(pred, part_type='train')
assert torch.all(pred_flat.predictions == y)
assert torch.all(pred_flat.truth == z)
assert pred_flat.part_type == 'train'
def test_batch_indices_01():
indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train')
assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4]))


Chargement…
Annuler
Enregistrer