If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

40 lines
1.3KB

  1. from icosagon.declayer import Predictions
  2. import torch
  3. class PredictionsBatch(object):
  4. def __init__(self, pred: Predictions, part_type: str = 'train',
  5. batch_size: int = 100) -> None:
  6. if not isinstance(pred, Predictions):
  7. raise TypeError('pred must be an instance of Predictions')
  8. if part_type not in ['train', 'val', 'test']:
  9. raise ValueError('part_type must be set to train, val or test')
  10. batch_size = int(batch_size)
  11. self.predictions = pred
  12. self.part_type = part_type
  13. self.batch_size = batch_size
  14. def __iter__(self):
  15. edge_types = [('edges_pos', 1), ('edges_neg', 0),
  16. ('edges_back_pos', 1), ('edges_back_neg', 0)]
  17. input = []
  18. target = []
  19. for fam in self.predictions.relation_families:
  20. for rel in fam.relation_types:
  21. for (et, tgt) in edge_types:
  22. edge_pred = getattr(getattr(rel, et), self.part_type)
  23. input.append(edge_pred)
  24. target.append(torch.ones_like(edge_pred) * tgt)
  25. input = torch.cat(input)
  26. target = torch.cat(target)
  27. for i in range(0, len(input), self.batch_size):
  28. yield (input[i:i+self.batch_size], target[i:i+self.batch_size])