| @@ -2,6 +2,8 @@ from icosagon.trainprep import PreparedData, \ | |||
| PreparedRelationFamily, \ | |||
| PreparedRelationType, \ | |||
| _empty_edge_list_tvt | |||
| import torch | |||
| import random | |||
| class BatchedData(PreparedData): | |||
| @@ -31,11 +33,13 @@ def batched_data_skeleton(data: PreparedData) -> BatchedData: | |||
| class DataBatcher(object): | |||
| def __init__(self, data: PreparedData, batch_size: int) -> None: | |||
| def __init__(self, data: PreparedData, batch_size: int, | |||
| shuffle: bool = True) -> None: | |||
| self._check_params(data, batch_size) | |||
| self.data = data | |||
| self.batch_size = batch_size | |||
| self.shuffle = shuffle | |||
| # def batched_data_iter(self, fam_idx: int, rel_idx: int, | |||
| # part_type: str) -> BatchedData: | |||
| @@ -71,17 +75,34 @@ class DataBatcher(object): | |||
| # yield batched_data | |||
| def __iter__(self) -> BatchedData: | |||
| gen = self.shuffle_iter() \ | |||
| if self.shuffle \ | |||
| else self.iter_base() | |||
| for batched_data in gen: | |||
| yield batched_data | |||
def iter_base(self) -> BatchedData:
    """Yield one batch per slice of every edge list in ``self.data``.

    Walks relation families, their relation types, the four edge kinds and
    the three train/val/test partitions in a fixed nested order. Each edge
    list is cut into consecutive slices of at most ``self.batch_size``
    items; every yielded object is a fresh ``batched_data_skeleton`` of
    ``self.data`` with exactly that one slice filled in.

    When ``self.shuffle`` is truthy the edges inside each list are permuted
    with ``torch.randperm`` before slicing; the order of the lists
    themselves is not changed here.
    """
    edge_kinds = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']
    partitions = ['train', 'val', 'test']
    step = self.batch_size

    for fam_idx, family in enumerate(self.data.relation_families):
        for rel_idx, relation in enumerate(family.relation_types):
            for kind in edge_kinds:
                for partition in partitions:
                    holder = getattr(relation, kind)
                    edges = getattr(holder, partition)
                    if self.shuffle:
                        order = torch.randperm(len(edges))
                        edges = edges[order]
                    for start in range(0, len(edges), step):
                        # A new skeleton per batch keeps yielded objects
                        # independent of one another.
                        batch = batched_data_skeleton(self.data)
                        target_rel = batch.relation_families[fam_idx] \
                            .relation_types[rel_idx]
                        target = getattr(target_rel, kind)
                        setattr(target, partition, edges[start:start + step])
                        yield batch
def shuffle_iter(self) -> BatchedData:
    """Yield the batches of ``iter_base()`` in a random order.

    All batches are collected into a list first so that ``random.shuffle``
    can permute them, then they are yielded one by one.
    """
    batches = [*self.iter_base()]
    random.shuffle(batches)
    yield from batches
| @staticmethod | |||
| def _check_params(data, batch_size): | |||
| if not isinstance(data, PreparedData): | |||