IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

63 lines
2.2KB

  1. from .data import Data
  2. from .model import TrainingBatch
  3. import torch
  4. def _shuffle(x: torch.Tensor) -> torch.Tensor:
  5. order = torch.randperm(len(x))
  6. return x[order]
  7. class Batcher(object):
  8. def __init__(self, data: Data, batch_size: int=512,
  9. shuffle: bool=True) -> None:
  10. if not isinstance(data, Data):
  11. raise TypeError('data must be an instance of Data')
  12. self.data = data
  13. self.batch_size = int(batch_size)
  14. self.shuffle = bool(shuffle)
  15. def __iter__(self) -> TrainingBatch:
  16. edge_types = list(self.data.edge_types.values())
  17. edge_lists = [ [ adj_mat.indices().transpose(0, 1) \
  18. for adj_mat in et.adjacency_matrices ] \
  19. for et in edge_types ]
  20. if self.shuffle:
  21. edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \
  22. for edge_lst in edge_lists ]
  23. offsets = [ [ 0 ] * len(et.adjacency_matrices) \
  24. for et in edge_types ]
  25. while True:
  26. candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \
  27. if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \
  28. if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ]
  29. if len(candidates) == 0:
  30. break
  31. edge_idx = torch.randint(0, len(candidates), (1,)).item()
  32. edge_idx = candidates[edge_idx]
  33. candidates = [ rel_idx \
  34. for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \
  35. if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]
  36. rel_idx = torch.randint(0, len(candidates), (1,)).item()
  37. rel_idx = candidates[rel_idx]
  38. lst = edge_lists[edge_idx][rel_idx]
  39. et = edge_types[edge_idx]
  40. ofs = offsets[edge_idx][rel_idx]
  41. lst = lst[ofs:ofs+self.batch_size]
  42. offsets[edge_idx][rel_idx] += self.batch_size
  43. b = TrainingBatch(et.vertex_type_row, et.vertex_type_column,
  44. rel_idx, lst, torch.full((len(lst),), self.data.target_value,
  45. dtype=torch.float32))
  46. yield b