|
|
@@ -0,0 +1,44 @@ |
|
|
|
import torch
|
|
|
|
from icosagon.trainprep import PreparedData
|
|
|
|
from icosagon.declayer import Predictions
|
|
|
|
|
|
|
|
|
|
|
|
class CrossEntropyLoss(torch.nn.Module):
|
|
|
|
def __init__(self, data: PreparedData, partition_type: str = 'train',
|
|
|
|
reduction: str = 'sum', **kwargs) -> None:
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
if not isinstance(data, PreparedData):
|
|
|
|
raise TypeError('data must be an instance of PreparedData')
|
|
|
|
|
|
|
|
if partition_type not in ['train', 'val', 'test']:
|
|
|
|
raise ValueError('partition_type must be set to train, val or test')
|
|
|
|
|
|
|
|
if reduction not in ['sum', 'mean']:
|
|
|
|
raise ValueError('reduction must be set to sum or mean')
|
|
|
|
|
|
|
|
self.data = data
|
|
|
|
self.partition_type = partition_type
|
|
|
|
self.reduction = reduction
|
|
|
|
|
|
|
|
def forward(self, pred: Predictions) -> torch.Tensor:
|
|
|
|
input = []
|
|
|
|
target = []
|
|
|
|
for fam in pred.relation_families:
|
|
|
|
for rel in fam.relation_types:
|
|
|
|
for edge_type in ['edges_pos', 'edges_back_pos']:
|
|
|
|
x = getattr(getattr(rel, edge_type), self.partition_type)
|
|
|
|
assert len(x.shape) == 1
|
|
|
|
input.append(x)
|
|
|
|
target.append(torch.ones_like(x))
|
|
|
|
for edge_type in ['edges_neg', 'edges_back_neg']:
|
|
|
|
x = getattr(getattr(rel, edge_type), self.partition_type)
|
|
|
|
assert len(x.shape) == 1
|
|
|
|
input.append(x)
|
|
|
|
target.append(torch.zeros_like(x))
|
|
|
|
input = torch.cat(input, dim=0)
|
|
|
|
target = torch.cat(target, dim=0)
|
|
|
|
res = torch.nn.functional.binary_cross_entropy(input, target,
|
|
|
|
reduction=self.reduction)
|
|
|
|
return res
|