If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

45 lines
1.7KB

  1. import torch
  2. from icosagon.trainprep import PreparedData
  3. from icosagon.declayer import Predictions
  4. class CrossEntropyLoss(torch.nn.Module):
  5. def __init__(self, data: PreparedData, partition_type: str = 'train',
  6. reduction: str = 'sum', **kwargs) -> None:
  7. super().__init__(**kwargs)
  8. if not isinstance(data, PreparedData):
  9. raise TypeError('data must be an instance of PreparedData')
  10. if partition_type not in ['train', 'val', 'test']:
  11. raise ValueError('partition_type must be set to train, val or test')
  12. if reduction not in ['sum', 'mean']:
  13. raise ValueError('reduction must be set to sum or mean')
  14. self.data = data
  15. self.partition_type = partition_type
  16. self.reduction = reduction
  17. def forward(self, pred: Predictions) -> torch.Tensor:
  18. input = []
  19. target = []
  20. for fam in pred.relation_families:
  21. for rel in fam.relation_types:
  22. for edge_type in ['edges_pos', 'edges_back_pos']:
  23. x = getattr(getattr(rel, edge_type), self.partition_type)
  24. assert len(x.shape) == 1
  25. input.append(x)
  26. target.append(torch.ones_like(x))
  27. for edge_type in ['edges_neg', 'edges_back_neg']:
  28. x = getattr(getattr(rel, edge_type), self.partition_type)
  29. assert len(x.shape) == 1
  30. input.append(x)
  31. target.append(torch.zeros_like(x))
  32. input = torch.cat(input, dim=0)
  33. target = torch.cat(target, dim=0)
  34. res = torch.nn.functional.binary_cross_entropy(input, target,
  35. reduction=self.reduction)
  36. return res