import torch
from icosagon.trainprep import PreparedData
from icosagon.declayer import Predictions

class CrossEntropyLoss(torch.nn.Module):
    """Binary cross-entropy loss over predicted edge probabilities.

    Gathers the predicted scores for positive and negative edges (both the
    forward and the "back" direction) from every relation type in every
    relation family of a ``Predictions`` object, labels positive edges 1 and
    negative edges 0, and evaluates
    ``torch.nn.functional.binary_cross_entropy`` over the concatenation.
    """

    def __init__(self, data: 'PreparedData', partition_type: str = 'train',
        reduction: str = 'sum', **kwargs) -> None:
        """Initialize the loss.

        Parameters
        ----------
        data : PreparedData
            The prepared dataset; stored for reference.
        partition_type : str
            Which split to read edge scores from: 'train', 'val' or 'test'.
        reduction : str
            Reduction passed to ``binary_cross_entropy``: 'sum' or 'mean'.

        Raises
        ------
        TypeError
            If ``data`` is not a ``PreparedData`` instance.
        ValueError
            If ``partition_type`` or ``reduction`` has an unsupported value.
        """
        super().__init__(**kwargs)

        if not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of PreparedData')

        if partition_type not in ['train', 'val', 'test']:
            raise ValueError('partition_type must be set to train, val or test')

        if reduction not in ['sum', 'mean']:
            raise ValueError('reduction must be set to sum or mean')

        self.data = data
        self.partition_type = partition_type
        self.reduction = reduction

    def forward(self, pred: 'Predictions') -> torch.Tensor:
        """Compute the binary cross-entropy over all edge predictions.

        Parameters
        ----------
        pred : Predictions
            Predictions container; each relation type must expose
            ``edges_pos``, ``edges_back_pos``, ``edges_neg`` and
            ``edges_back_neg``, each with a 1-D score tensor for the
            configured partition.

        Returns
        -------
        torch.Tensor
            Scalar loss (reduced with the configured ``reduction``).

        Raises
        ------
        ValueError
            If any score tensor is not one-dimensional.
        """
        # Renamed from `input`/`target` to avoid shadowing the builtin.
        scores = []
        labels = []
        # Positive edges get label 1, negative edges label 0; the two
        # original near-duplicate loops are merged via a label factory.
        edge_labels = (
            ('edges_pos', torch.ones_like),
            ('edges_back_pos', torch.ones_like),
            ('edges_neg', torch.zeros_like),
            ('edges_back_neg', torch.zeros_like),
        )
        for fam in pred.relation_families:
            for rel in fam.relation_types:
                for edge_type, make_labels in edge_labels:
                    x = getattr(getattr(rel, edge_type), self.partition_type)
                    # Explicit check instead of `assert`, which is stripped
                    # under `python -O`.
                    if x.dim() != 1:
                        raise ValueError(
                            'Expected a 1-D tensor of edge scores for %s, '
                            'got shape %s' % (edge_type, tuple(x.shape)))
                    scores.append(x)
                    labels.append(make_labels(x))
        res = torch.nn.functional.binary_cross_entropy(
            torch.cat(scores, dim=0), torch.cat(labels, dim=0),
            reduction=self.reduction)
        return res