IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

118 lines
4.8KB

import random
from typing import Iterator

import torch

from icosagon.trainprep import PreparedData, \
    PreparedRelationFamily, \
    PreparedRelationType, \
    _empty_edge_list_tvt
  7. class BatchedData(PreparedData):
  8. def __init__(self, *args, **kwargs):
  9. super().__init__(*args, **kwargs)
  10. class BatchedDataPointer(object):
  11. def __init__(self, batched_data):
  12. self.batched_data = batched_data
  13. def batched_data_skeleton(data: PreparedData) -> BatchedData:
  14. if not isinstance(data, PreparedData):
  15. raise TypeError('data must be an instance of PreparedData')
  16. fam_skels = []
  17. for fam in data.relation_families:
  18. rel_types_skel = []
  19. for rel in fam.relation_types:
  20. rel_skel = PreparedRelationType(rel.name,
  21. rel.node_type_row, rel.node_type_column,
  22. rel.adjacency_matrix, rel.adjacency_matrix_backward,
  23. _empty_edge_list_tvt(), _empty_edge_list_tvt(),
  24. _empty_edge_list_tvt(), _empty_edge_list_tvt())
  25. rel_types_skel.append(rel_skel)
  26. fam_skels.append(PreparedRelationFamily(fam.data, fam.name,
  27. fam.node_type_row, fam.node_type_column,
  28. fam.is_symmetric, fam.decoder_class,
  29. rel_types_skel))
  30. return BatchedData(data.node_types, fam_skels)
  31. class DataBatcher(object):
  32. def __init__(self, data: PreparedData, batch_size: int,
  33. shuffle: bool = True) -> None:
  34. self._check_params(data, batch_size)
  35. self.data = data
  36. self.batch_size = batch_size
  37. self.shuffle = shuffle
  38. # def batched_data_iter(self, fam_idx: int, rel_idx: int,
  39. # part_type: str) -> BatchedData:
  40. #
  41. # rel = self.data.relation_families[fam_idx].relation_types[rel_idx]
  42. #
  43. # edges = getattr(rel.edges_pos, part_type)
  44. # for m in range(0, len(edges), self.batch_size):
  45. # batched_data = batched_data_skeleton(self.data)
  46. # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_pos,
  47. # part_type, edges[m : m + self.batch_size])
  48. # yield batched_data
  49. #
  50. # edges = getattr(rel.edges_neg, part_type)
  51. # for m in range(0, len(edges), self.batch_size):
  52. # batched_data = batched_data_skeleton(self.data)
  53. # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_neg,
  54. # part_type, edges[m : m + self.batch_size])
  55. # yield batched_data
  56. #
  57. # edges = getattr(rel.edges_pos_back, part_type)
  58. # for m in range(0, len(edges), self.batch_size):
  59. # batched_data = batched_data_skeleton(self.data)
  60. # setattr(batched_data.relation_families[i].relation_types[k].edges_pos_back,
  61. # part_type, edges[m : m + self.batch_size])
  62. # yield batched_data
  63. #
  64. # edges = getattr(rel.edges_neg_back, part_type)
  65. # for m in range(0, len(), self.batch_size):
  66. # batched_data = batched_data_skeleton(self.data)
  67. # setattr(batched_data.relation_families[i].relation_types[k].edges_neg_back,
  68. # edges[m : m + self.batch_size])
  69. # yield batched_data
  70. def __iter__(self) -> BatchedData:
  71. gen = self.shuffle_iter() \
  72. if self.shuffle \
  73. else self.iter_base()
  74. for batched_data in gen:
  75. yield batched_data
  76. def iter_base(self) -> BatchedData:
  77. for i, fam in enumerate(self.data.relation_families):
  78. for k, rel in enumerate(fam.relation_types):
  79. for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']:
  80. for part_type in ['train', 'val', 'test']:
  81. edges = getattr(getattr(rel, edge_type), part_type)
  82. if self.shuffle:
  83. perm = torch.randperm(len(edges))
  84. edges = edges[perm]
  85. for m in range(0, len(edges), self.batch_size):
  86. batched_data = batched_data_skeleton(self.data)
  87. setattr(getattr(batched_data.relation_families[i].relation_types[k],
  88. edge_type), part_type, edges[m : m + self.batch_size])
  89. yield batched_data
  90. def shuffle_iter(self) -> BatchedData:
  91. res = list(self.iter_base())
  92. random.shuffle(res)
  93. for batched_data in res:
  94. yield batched_data
  95. @staticmethod
  96. def _check_params(data, batch_size):
  97. if not isinstance(data, PreparedData):
  98. raise TypeError('data must be an instance of PreparedData')
  99. if not isinstance(batch_size, int):
  100. raise TypeError('batch_size must be an int')