import torch
from icosagon.trainprep import PreparedData
from icosagon.declayer import Predictions


class CrossEntropyLoss(torch.nn.Module):
    """Binary cross-entropy loss over positive and negative edge predictions.

    For every relation type of every relation family in a ``Predictions``
    object, the predictions for one data partition (``train``, ``val`` or
    ``test``) are gathered from four edge groups:

    * ``edges_pos`` and ``edges_back_pos`` are scored against target 1;
    * ``edges_neg`` and ``edges_back_neg`` are scored against target 0.

    All gathered values are concatenated into one 1-D tensor and passed to
    ``torch.nn.functional.binary_cross_entropy``. That function expects
    probabilities in ``[0, 1]``, not logits.

    Args:
        data: Prepared dataset. It is type-checked and stored as
            ``self.data`` but is not read by ``forward``.
        partition_type: Which partition attribute to read from each edge
            group. Must be ``'train'``, ``'val'`` or ``'test'``.
        reduction: ``'sum'`` or ``'mean'``; forwarded to
            ``binary_cross_entropy``.
        **kwargs: Passed through to ``torch.nn.Module.__init__``.

    Raises:
        TypeError: If ``data`` is not a ``PreparedData`` instance.
        ValueError: If ``partition_type`` or ``reduction`` is not one of
            the allowed values.
    """

    # (edge-group attribute name, factory building the matching target).
    # The order matches the original implementation, so the concatenation
    # order (and hence the float summation order) is unchanged.
    _EDGE_TARGETS = (
        ('edges_pos', torch.ones_like),
        ('edges_back_pos', torch.ones_like),
        ('edges_neg', torch.zeros_like),
        ('edges_back_neg', torch.zeros_like),
    )

    def __init__(self, data: PreparedData, partition_type: str = 'train',
                 reduction: str = 'sum', **kwargs) -> None:
        super().__init__(**kwargs)

        if not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of PreparedData')

        if partition_type not in ['train', 'val', 'test']:
            raise ValueError('partition_type must be set to train, val or test')

        if reduction not in ['sum', 'mean']:
            raise ValueError('reduction must be set to sum or mean')

        self.data = data
        self.partition_type = partition_type
        self.reduction = reduction

    def forward(self, pred: Predictions) -> torch.Tensor:
        """Compute the binary cross-entropy for the configured partition.

        Args:
            pred: Object exposing ``relation_families``. Each family exposes
                ``relation_types``, and each relation type exposes the four
                edge groups listed in the class docstring. Each edge group has
                an attribute named after ``self.partition_type``.

        Returns:
            The loss tensor returned by ``binary_cross_entropy`` with
            ``reduction=self.reduction``.

        Raises:
            ValueError: If any gathered prediction tensor is not 1-D.
        """
        predictions = []
        targets = []

        for fam in pred.relation_families:
            for rel in fam.relation_types:
                for edge_type, make_target in self._EDGE_TARGETS:
                    x = getattr(getattr(rel, edge_type), self.partition_type)
                    # Explicit check instead of `assert`, which is stripped
                    # under `python -O`. A non-1-D tensor would otherwise be
                    # concatenated along dim 0 and scored incorrectly.
                    if x.dim() != 1:
                        raise ValueError(
                            f'{edge_type}.{self.partition_type} must be a 1-D '
                            f'tensor, got shape {tuple(x.shape)}')
                    predictions.append(x)
                    targets.append(make_target(x))

        predictions = torch.cat(predictions, dim=0)
        targets = torch.cat(targets, dim=0)

        return torch.nn.functional.binary_cross_entropy(
            predictions, targets, reduction=self.reduction)