If you would like an account, please send an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

47 lines
1.8KB

import torch

from icosagon.data import Data
from icosagon.declayer import (
    Predictions,
    RelationFamilyPredictions,
    RelationPredictions,
)
from icosagon.loss import CrossEntropyLoss
from icosagon.trainprep import (
    prepare_training,
    TrainValTest,
)
  9. def test_cross_entropy_loss_01():
  10. d = Data()
  11. d.add_node_type('Dummy', 5)
  12. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  13. fam.add_relation_type('Dummy Rel', torch.tensor([
  14. [0, 1, 0, 0, 0],
  15. [1, 0, 0, 0, 0],
  16. [0, 0, 0, 1, 0],
  17. [0, 0, 0, 0, 1],
  18. [0, 1, 0, 0, 0]
  19. ], dtype=torch.float32))
  20. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  21. assert len(prep_d.relation_families) == 1
  22. assert len(prep_d.relation_families[0].relation_types) == 1
  23. assert len(prep_d.relation_families[0].relation_types[0].edges_pos.train) == 5
  24. assert len(prep_d.relation_families[0].relation_types[0].edges_pos.val) == 0
  25. assert len(prep_d.relation_families[0].relation_types[0].edges_pos.test) == 0
  26. rel_pred = RelationPredictions(
  27. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  28. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  29. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  30. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  31. )
  32. fam_pred = RelationFamilyPredictions([ rel_pred ])
  33. pred = Predictions([ fam_pred ])
  34. loss = CrossEntropyLoss(prep_d)
  35. print('loss: %.7f' % loss(pred))
  36. assert torch.abs(loss(pred) - 55.262043) < 0.000001
  37. loss = CrossEntropyLoss(prep_d, reduction='mean')
  38. print('loss: %.7f' % loss(pred))
  39. assert torch.abs(loss(pred) - 11.0524082) < 0.000001