|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- from icosagon.loss import CrossEntropyLoss
- from icosagon.declayer import Predictions, \
- RelationFamilyPredictions, \
- RelationPredictions
- from icosagon.data import Data
- from icosagon.trainprep import prepare_training, \
- TrainValTest
- import torch
-
-
def test_cross_entropy_loss_01():
    """Check CrossEntropyLoss on a tiny single-relation graph.

    Builds a graph with one node type of 5 nodes and one relation type
    whose adjacency matrix has 5 non-zero entries. It puts every edge in the
    training split and evaluates the loss on hand-made predictions, once
    with the default reduction and once with ``reduction='mean'``.
    """
    # Function-scope import keeps this test self-contained.
    import math

    d = Data()
    d.add_node_type('Dummy', 5)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    fam.add_relation_type('Dummy Rel', torch.tensor([
        [0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0],
    ], dtype=torch.float32))

    # Ratios 1. / 0. / 0.: all 5 positive edges are expected in train.
    prep_d = prepare_training(d, TrainValTest(1., 0., 0.))

    assert len(prep_d.relation_families) == 1
    assert len(prep_d.relation_families[0].relation_types) == 1
    rel = prep_d.relation_families[0].relation_types[0]
    assert len(rel.edges_pos.train) == 5
    assert len(rel.edges_pos.val) == 0
    assert len(rel.edges_pos.test) == 0

    def empty_split():
        # Fresh empty train/val/test tensors for each call.
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    # NOTE(review): only the first TrainValTest carries scores (5 train
    # values, presumably one per positive train edge); the other three are
    # empty. Confirm the argument meaning against RelationPredictions.
    rel_pred = RelationPredictions(
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                     torch.zeros(0), torch.zeros(0)),
        empty_split(),
        empty_split(),
        empty_split(),
    )
    fam_pred = RelationFamilyPredictions([rel_pred])
    pred = Predictions([fam_pred])

    # Evaluate each loss once and reuse the value for the print and the check.
    # Use a relative tolerance: assuming the loss is float32 like its inputs,
    # one float32 ulp near 55 is ~3.8e-6. An absolute bound of 1e-6 would
    # therefore demand a bit-exact result, which breaks under any change in
    # reduction order or backend.
    loss = CrossEntropyLoss(prep_d)
    loss_value = float(loss(pred))
    print('loss: %.7f' % loss_value)
    assert math.isclose(loss_value, 55.262043, rel_tol=1e-6)

    loss = CrossEntropyLoss(prep_d, reduction='mean')
    loss_value = float(loss(pred))
    print('loss: %.7f' % loss_value)
    assert math.isclose(loss_value, 11.0524082, rel_tol=1e-6)
|