IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

92 wiersze
4.0KB

from typing import Iterator

from icosagon.trainprep import PreparedData, \
    PreparedRelationFamily, \
    PreparedRelationType, \
    _empty_edge_list_tvt
  5. class BatchedData(PreparedData):
  6. def __init__(self, *args, **kwargs):
  7. super().__init__(*args, **kwargs)
  8. def batched_data_skeleton(data: PreparedData) -> BatchedData:
  9. if not isinstance(data, PreparedData):
  10. raise TypeError('data must be an instance of PreparedData')
  11. fam_skels = []
  12. for fam in data.relation_families:
  13. rel_types_skel = []
  14. for rel in fam.relation_types:
  15. rel_skel = PreparedRelationType(rel.name,
  16. rel.node_type_row, rel.node_type_column,
  17. rel.adjacency_matrix, rel.adjacency_matrix_backward,
  18. _empty_edge_list_tvt(), _empty_edge_list_tvt(),
  19. _empty_edge_list_tvt(), _empty_edge_list_tvt())
  20. rel_types_skel.append(rel_skel)
  21. fam_skels.append(PreparedRelationFamily(fam.data, fam.name,
  22. fam.node_type_row, fam.node_type_column,
  23. fam.is_symmetric, fam.decoder_class,
  24. rel_types_skel))
  25. return BatchedData(data.node_types, fam_skels)
  26. class DataBatcher(object):
  27. def __init__(self, data: PreparedData, batch_size: int) -> None:
  28. self._check_params(data, batch_size)
  29. self.data = data
  30. self.batch_size = batch_size
  31. # def batched_data_iter(self, fam_idx: int, rel_idx: int,
  32. # part_type: str) -> BatchedData:
  33. #
  34. # rel = self.data.relation_families[fam_idx].relation_types[rel_idx]
  35. #
  36. # edges = getattr(rel.edges_pos, part_type)
  37. # for m in range(0, len(edges), self.batch_size):
  38. # batched_data = batched_data_skeleton(self.data)
  39. # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_pos,
  40. # part_type, edges[m : m + self.batch_size])
  41. # yield batched_data
  42. #
  43. # edges = getattr(rel.edges_neg, part_type)
  44. # for m in range(0, len(edges), self.batch_size):
  45. # batched_data = batched_data_skeleton(self.data)
  46. # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_neg,
  47. # part_type, edges[m : m + self.batch_size])
  48. # yield batched_data
  49. #
  50. # edges = getattr(rel.edges_pos_back, part_type)
  51. # for m in range(0, len(edges), self.batch_size):
  52. # batched_data = batched_data_skeleton(self.data)
  53. # setattr(batched_data.relation_families[i].relation_types[k].edges_pos_back,
  54. # part_type, edges[m : m + self.batch_size])
  55. # yield batched_data
  56. #
  57. # edges = getattr(rel.edges_neg_back, part_type)
  58. # for m in range(0, len(), self.batch_size):
  59. # batched_data = batched_data_skeleton(self.data)
  60. # setattr(batched_data.relation_families[i].relation_types[k].edges_neg_back,
  61. # edges[m : m + self.batch_size])
  62. # yield batched_data
  63. def __iter__(self) -> BatchedData:
  64. for i, fam in enumerate(self.data.relation_families):
  65. for k, rel in enumerate(fam.relation_types):
  66. for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']:
  67. for part_type in ['train', 'val', 'test']:
  68. edges = getattr(getattr(rel, edge_type), part_type)
  69. for m in range(0, len(edges), self.batch_size):
  70. batched_data = batched_data_skeleton(self.data)
  71. setattr(getattr(batched_data.relation_families[i].relation_types[k],
  72. edge_type), part_type, edges[m : m + self.batch_size])
  73. yield batched_data
  74. @staticmethod
  75. def _check_params(data, batch_size):
  76. if not isinstance(data, PreparedData):
  77. raise TypeError('data must be an instance of PreparedData')
  78. if not isinstance(batch_size, int):
  79. raise TypeError('batch_size must be an int')