IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

113 строк
4.7KB

import random
from typing import Iterator

import torch

from icosagon.trainprep import PreparedData, \
    PreparedRelationFamily, \
    PreparedRelationType, \
    _empty_edge_list_tvt
  7. class BatchedData(PreparedData):
  8. def __init__(self, *args, **kwargs):
  9. super().__init__(*args, **kwargs)
  10. def batched_data_skeleton(data: PreparedData) -> BatchedData:
  11. if not isinstance(data, PreparedData):
  12. raise TypeError('data must be an instance of PreparedData')
  13. fam_skels = []
  14. for fam in data.relation_families:
  15. rel_types_skel = []
  16. for rel in fam.relation_types:
  17. rel_skel = PreparedRelationType(rel.name,
  18. rel.node_type_row, rel.node_type_column,
  19. rel.adjacency_matrix, rel.adjacency_matrix_backward,
  20. _empty_edge_list_tvt(), _empty_edge_list_tvt(),
  21. _empty_edge_list_tvt(), _empty_edge_list_tvt())
  22. rel_types_skel.append(rel_skel)
  23. fam_skels.append(PreparedRelationFamily(fam.data, fam.name,
  24. fam.node_type_row, fam.node_type_column,
  25. fam.is_symmetric, fam.decoder_class,
  26. rel_types_skel))
  27. return BatchedData(data.node_types, fam_skels)
  28. class DataBatcher(object):
  29. def __init__(self, data: PreparedData, batch_size: int,
  30. shuffle: bool = True) -> None:
  31. self._check_params(data, batch_size)
  32. self.data = data
  33. self.batch_size = batch_size
  34. self.shuffle = shuffle
  35. # def batched_data_iter(self, fam_idx: int, rel_idx: int,
  36. # part_type: str) -> BatchedData:
  37. #
  38. # rel = self.data.relation_families[fam_idx].relation_types[rel_idx]
  39. #
  40. # edges = getattr(rel.edges_pos, part_type)
  41. # for m in range(0, len(edges), self.batch_size):
  42. # batched_data = batched_data_skeleton(self.data)
  43. # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_pos,
  44. # part_type, edges[m : m + self.batch_size])
  45. # yield batched_data
  46. #
  47. # edges = getattr(rel.edges_neg, part_type)
  48. # for m in range(0, len(edges), self.batch_size):
  49. # batched_data = batched_data_skeleton(self.data)
  50. # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_neg,
  51. # part_type, edges[m : m + self.batch_size])
  52. # yield batched_data
  53. #
  54. # edges = getattr(rel.edges_pos_back, part_type)
  55. # for m in range(0, len(edges), self.batch_size):
  56. # batched_data = batched_data_skeleton(self.data)
  57. # setattr(batched_data.relation_families[i].relation_types[k].edges_pos_back,
  58. # part_type, edges[m : m + self.batch_size])
  59. # yield batched_data
  60. #
  61. # edges = getattr(rel.edges_neg_back, part_type)
  62. # for m in range(0, len(), self.batch_size):
  63. # batched_data = batched_data_skeleton(self.data)
  64. # setattr(batched_data.relation_families[i].relation_types[k].edges_neg_back,
  65. # edges[m : m + self.batch_size])
  66. # yield batched_data
  67. def __iter__(self) -> BatchedData:
  68. gen = self.shuffle_iter() \
  69. if self.shuffle \
  70. else self.iter_base()
  71. for batched_data in gen:
  72. yield batched_data
  73. def iter_base(self) -> BatchedData:
  74. for i, fam in enumerate(self.data.relation_families):
  75. for k, rel in enumerate(fam.relation_types):
  76. for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']:
  77. for part_type in ['train', 'val', 'test']:
  78. edges = getattr(getattr(rel, edge_type), part_type)
  79. if self.shuffle:
  80. perm = torch.randperm(len(edges))
  81. edges = edges[perm]
  82. for m in range(0, len(edges), self.batch_size):
  83. batched_data = batched_data_skeleton(self.data)
  84. setattr(getattr(batched_data.relation_families[i].relation_types[k],
  85. edge_type), part_type, edges[m : m + self.batch_size])
  86. yield batched_data
  87. def shuffle_iter(self) -> BatchedData:
  88. res = list(self.iter_base())
  89. random.shuffle(res)
  90. for batched_data in res:
  91. yield batched_data
  92. @staticmethod
  93. def _check_params(data, batch_size):
  94. if not isinstance(data, PreparedData):
  95. raise TypeError('data must be an instance of PreparedData')
  96. if not isinstance(batch_size, int):
  97. raise TypeError('batch_size must be an int')