IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un chiffre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

batch.py 1.5KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from icosagon.declayer import Predictions
  2. import torch
  3. class PredictionsBatch(object):
  4. def __init__(self, pred: Predictions, part_type: str = 'train',
  5. batch_size: int = 100, shuffle: bool = False) -> None:
  6. if not isinstance(pred, Predictions):
  7. raise TypeError('pred must be an instance of Predictions')
  8. if part_type not in ['train', 'val', 'test']:
  9. raise ValueError('part_type must be set to train, val or test')
  10. batch_size = int(batch_size)
  11. shuffle = bool(shuffle)
  12. self.predictions = pred
  13. self.part_type = part_type
  14. self.batch_size = batch_size
  15. self.shuffle = shuffle
  16. def __iter__(self):
  17. edge_types = [('edges_pos', 1), ('edges_neg', 0),
  18. ('edges_back_pos', 1), ('edges_back_neg', 0)]
  19. input = []
  20. target = []
  21. for fam in self.predictions.relation_families:
  22. for rel in fam.relation_types:
  23. for (et, tgt) in edge_types:
  24. edge_pred = getattr(getattr(rel, et), self.part_type)
  25. input.append(edge_pred)
  26. target.append(torch.ones_like(edge_pred) * tgt)
  27. input = torch.cat(input)
  28. target = torch.cat(target)
  29. if self.shuffle:
  30. perm = torch.randperm(len(input))
  31. input = input[perm]
  32. target = target[perm]
  33. for i in range(0, len(input), self.batch_size):
  34. yield (input[i:i+self.batch_size], target[i:i+self.batch_size])