IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

48 lines
1.5KB

  1. from icosagon.declayer import Predictions
  2. import torch
  3. class PredictionsBatch(object):
  4. def __init__(self, pred: Predictions, part_type: str = 'train',
  5. batch_size: int = 100, shuffle: bool = False) -> None:
  6. if not isinstance(pred, Predictions):
  7. raise TypeError('pred must be an instance of Predictions')
  8. if part_type not in ['train', 'val', 'test']:
  9. raise ValueError('part_type must be set to train, val or test')
  10. batch_size = int(batch_size)
  11. shuffle = bool(shuffle)
  12. self.predictions = pred
  13. self.part_type = part_type
  14. self.batch_size = batch_size
  15. self.shuffle = shuffle
  16. def __iter__(self):
  17. edge_types = [('edges_pos', 1), ('edges_neg', 0),
  18. ('edges_back_pos', 1), ('edges_back_neg', 0)]
  19. input = []
  20. target = []
  21. for fam in self.predictions.relation_families:
  22. for rel in fam.relation_types:
  23. for (et, tgt) in edge_types:
  24. edge_pred = getattr(getattr(rel, et), self.part_type)
  25. input.append(edge_pred)
  26. target.append(torch.ones_like(edge_pred) * tgt)
  27. input = torch.cat(input)
  28. target = torch.cat(target)
  29. if self.shuffle:
  30. perm = torch.randperm(len(input))
  31. input = input[perm]
  32. target = target[perm]
  33. for i in range(0, len(input), self.batch_size):
  34. yield (input[i:i+self.batch_size], target[i:i+self.batch_size])