diff --git a/src/icosagon/loss.py b/src/icosagon/loss.py new file mode 100644 index 0000000..eb2e0fe --- /dev/null +++ b/src/icosagon/loss.py @@ -0,0 +1,44 @@ +import torch +from icosagon.trainprep import PreparedData +from icosagon.declayer import Predictions + + +class CrossEntropyLoss(torch.nn.Module): + def __init__(self, data: PreparedData, partition_type: str = 'train', + reduction: str = 'sum', **kwargs) -> None: + + super().__init__(**kwargs) + + if not isinstance(data, PreparedData): + raise TypeError('data must be an instance of PreparedData') + + if partition_type not in ['train', 'val', 'test']: + raise ValueError('partition_type must be set to train, val or test') + + if reduction not in ['sum', 'mean']: + raise ValueError('reduction must be set to sum or mean') + + self.data = data + self.partition_type = partition_type + self.reduction = reduction + + def forward(self, pred: Predictions) -> torch.Tensor: + input = [] + target = [] + for fam in pred.relation_families: + for rel in fam.relation_types: + for edge_type in ['edges_pos', 'edges_back_pos']: + x = getattr(getattr(rel, edge_type), self.partition_type) + assert len(x.shape) == 1 + input.append(x) + target.append(torch.ones_like(x)) + for edge_type in ['edges_neg', 'edges_back_neg']: + x = getattr(getattr(rel, edge_type), self.partition_type) + assert len(x.shape) == 1 + input.append(x) + target.append(torch.zeros_like(x)) + input = torch.cat(input, dim=0) + target = torch.cat(target, dim=0) + res = torch.nn.functional.binary_cross_entropy(input, target, + reduction=self.reduction) + return res diff --git a/tests/icosagon/test_loss.py b/tests/icosagon/test_loss.py new file mode 100644 index 0000000..07e5e5c --- /dev/null +++ b/tests/icosagon/test_loss.py @@ -0,0 +1,40 @@ +from icosagon.loss import CrossEntropyLoss +from icosagon.declayer import Predictions, \ + RelationFamilyPredictions, \ + RelationPredictions +from icosagon.data import Data +from icosagon.trainprep import prepare_training, \ + TrainValTest +import torch + + +def test_cross_entropy_loss_01(): + d = Data() + d.add_node_type('Dummy', 5) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Rel', torch.tensor([ + [0, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0] + ], dtype=torch.float32)) + + prep_d = prepare_training(d, TrainValTest(1., 0., 0.)) + + rel_pred = RelationPredictions( + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) + ) + fam_pred = RelationFamilyPredictions([ rel_pred ]) + pred = Predictions([ fam_pred ]) + + loss = CrossEntropyLoss(prep_d) + print('loss: %.7f' % loss(pred)) + assert torch.abs(loss(pred) - 55.262043) < 0.000001 + + loss = CrossEntropyLoss(prep_d, reduction='mean') + print('loss: %.7f' % loss(pred)) + assert torch.abs(loss(pred) - 11.0524082) < 0.000001