IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

47 lines
1.7KB

  1. from icosagon.batch import PredictionsBatch
  2. from icosagon.declayer import Predictions, \
  3. RelationPredictions, \
  4. RelationFamilyPredictions
  5. from icosagon.trainprep import prepare_training, \
  6. TrainValTest
  7. from icosagon.data import Data
  8. import torch
  9. def test_predictions_batch_01():
  10. d = Data()
  11. d.add_node_type('Dummy', 5)
  12. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  13. fam.add_relation_type('Dummy Rel', torch.tensor([
  14. [0, 1, 0, 0, 0],
  15. [1, 0, 0, 0, 0],
  16. [0, 0, 0, 1, 0],
  17. [0, 0, 0, 0, 1],
  18. [0, 1, 0, 0, 0]
  19. ], dtype=torch.float32))
  20. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  21. assert len(prep_d.relation_families) == 1
  22. assert len(prep_d.relation_families[0].relation_types) == 1
  23. assert len(prep_d.relation_families[0].relation_types[0].edges_pos.train) == 5
  24. assert len(prep_d.relation_families[0].relation_types[0].edges_neg.train) == 5
  25. assert len(prep_d.relation_families[0].relation_types[0].edges_pos.val) == 0
  26. assert len(prep_d.relation_families[0].relation_types[0].edges_pos.test) == 0
  27. rel_pred = RelationPredictions(
  28. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  29. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  30. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  31. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  32. )
  33. fam_pred = RelationFamilyPredictions([ rel_pred ])
  34. pred = Predictions([ fam_pred ])
  35. batch = PredictionsBatch(pred, part_type='train', batch_size=1)
  36. count = 0
  37. for (input, target) in batch:
  38. count += 1
  39. assert count == 10