IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

41 lines
1.4KB

  1. from icosagon.loss import CrossEntropyLoss
  2. from icosagon.declayer import Predictions, \
  3. RelationFamilyPredictions, \
  4. RelationPredictions
  5. from icosagon.data import Data
  6. from icosagon.trainprep import prepare_training, \
  7. TrainValTest
  8. import torch
  9. def test_cross_entropy_loss_01():
  10. d = Data()
  11. d.add_node_type('Dummy', 5)
  12. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  13. fam.add_relation_type('Dummy Rel', torch.tensor([
  14. [0, 1, 0, 0, 0],
  15. [1, 0, 0, 0, 0],
  16. [0, 0, 0, 1, 0],
  17. [0, 0, 0, 0, 1],
  18. [0, 1, 0, 0, 0]
  19. ], dtype=torch.float32))
  20. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  21. rel_pred = RelationPredictions(
  22. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  23. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  24. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  25. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  26. )
  27. fam_pred = RelationFamilyPredictions([ rel_pred ])
  28. pred = Predictions([ fam_pred ])
  29. loss = CrossEntropyLoss(prep_d)
  30. print('loss: %.7f' % loss(pred))
  31. assert torch.abs(loss(pred) - 55.262043) < 0.000001
  32. loss = CrossEntropyLoss(prep_d, reduction='mean')
  33. print('loss: %.7f' % loss(pred))
  34. assert torch.abs(loss(pred) - 11.0524082) < 0.000001