| @@ -2,6 +2,8 @@ from icosagon.trainprep import PreparedData, \ | |||||
| PreparedRelationFamily, \ | PreparedRelationFamily, \ | ||||
| PreparedRelationType, \ | PreparedRelationType, \ | ||||
| _empty_edge_list_tvt | _empty_edge_list_tvt | ||||
| import torch | |||||
| import random | |||||
| class BatchedData(PreparedData): | class BatchedData(PreparedData): | ||||
| @@ -31,11 +33,13 @@ def batched_data_skeleton(data: PreparedData) -> BatchedData: | |||||
| class DataBatcher(object): | class DataBatcher(object): | ||||
| def __init__(self, data: PreparedData, batch_size: int) -> None: | |||||
| def __init__(self, data: PreparedData, batch_size: int, | |||||
| shuffle: bool = True) -> None: | |||||
| self._check_params(data, batch_size) | self._check_params(data, batch_size) | ||||
| self.data = data | self.data = data | ||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.shuffle = shuffle | |||||
| # def batched_data_iter(self, fam_idx: int, rel_idx: int, | # def batched_data_iter(self, fam_idx: int, rel_idx: int, | ||||
| # part_type: str) -> BatchedData: | # part_type: str) -> BatchedData: | ||||
| @@ -71,17 +75,34 @@ class DataBatcher(object): | |||||
| # yield batched_data | # yield batched_data | ||||
| def __iter__(self) -> BatchedData: | def __iter__(self) -> BatchedData: | ||||
| gen = self.shuffle_iter() \ | |||||
| if self.shuffle \ | |||||
| else self.iter_base() | |||||
| for batched_data in gen: | |||||
| yield batched_data | |||||
| def iter_base(self) -> BatchedData: | |||||
| for i, fam in enumerate(self.data.relation_families): | for i, fam in enumerate(self.data.relation_families): | ||||
| for k, rel in enumerate(fam.relation_types): | for k, rel in enumerate(fam.relation_types): | ||||
| for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']: | for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']: | ||||
| for part_type in ['train', 'val', 'test']: | for part_type in ['train', 'val', 'test']: | ||||
| edges = getattr(getattr(rel, edge_type), part_type) | edges = getattr(getattr(rel, edge_type), part_type) | ||||
| if self.shuffle: | |||||
| perm = torch.randperm(len(edges)) | |||||
| edges = edges[perm] | |||||
| for m in range(0, len(edges), self.batch_size): | for m in range(0, len(edges), self.batch_size): | ||||
| batched_data = batched_data_skeleton(self.data) | batched_data = batched_data_skeleton(self.data) | ||||
| setattr(getattr(batched_data.relation_families[i].relation_types[k], | setattr(getattr(batched_data.relation_families[i].relation_types[k], | ||||
| edge_type), part_type, edges[m : m + self.batch_size]) | edge_type), part_type, edges[m : m + self.batch_size]) | ||||
| yield batched_data | yield batched_data | ||||
| def shuffle_iter(self) -> BatchedData: | |||||
| res = list(self.iter_base()) | |||||
| random.shuffle(res) | |||||
| for batched_data in res: | |||||
| yield batched_data | |||||
| @staticmethod | @staticmethod | ||||
| def _check_params(data, batch_size): | def _check_params(data, batch_size): | ||||
| if not isinstance(data, PreparedData): | if not isinstance(data, PreparedData): | ||||