"""Mini-batching of :class:`PreparedData` edge lists.

:class:`DataBatcher` walks every edge list of every relation type in a
:class:`PreparedData` and cuts it into chunks of at most ``batch_size``
edges. Each chunk is yielded as a :class:`BatchedData` built by
:func:`batched_data_skeleton`, in which that one chunk is the only edge list
that has been filled in.
"""

import random
from typing import Iterator

import torch

from icosagon.trainprep import (
    PreparedData,
    PreparedRelationFamily,
    PreparedRelationType,
    _empty_edge_list_tvt,
)


# Attribute names of the edge-list containers on a PreparedRelationType,
# visited in this order by DataBatcher.iter_base.
_EDGE_TYPES = ('edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg')

# Attribute names of the splits inside each edge-list container.
_PART_TYPES = ('train', 'val', 'test')


class BatchedData(PreparedData):
    """A :class:`PreparedData` carrying a single mini-batch of edges.

    Adds no behaviour of its own; the distinct type marks instances produced
    by :func:`batched_data_skeleton` / :class:`DataBatcher`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class BatchedDataPointer(object):
    """Mutable holder for a reference to a :class:`BatchedData`.

    Attributes:
        batched_data: The referenced batch (stored as given).
    """

    def __init__(self, batched_data):
        self.batched_data = batched_data


def batched_data_skeleton(data: PreparedData) -> BatchedData:
    """Copy the structure of ``data`` with fresh placeholder edge lists.

    Node types, relation-family attributes, relation names/node types and
    adjacency matrices are passed through by reference (not copied). Every
    edge-list container (``edges_pos``, ``edges_neg``, ``edges_back_pos``,
    ``edges_back_neg``) is replaced by a new object from
    ``_empty_edge_list_tvt()``.

    Args:
        data: Source data whose structure is mirrored.

    Returns:
        A new :class:`BatchedData` with the same families and relation types.

    Raises:
        TypeError: If ``data`` is not a :class:`PreparedData`.
    """
    if not isinstance(data, PreparedData):
        raise TypeError('data must be an instance of PreparedData')

    fam_skels = []
    for fam in data.relation_families:
        # A separate _empty_edge_list_tvt() call per container gives each
        # skeleton its own objects, so filling one batch never touches
        # another batch or the source data.
        rel_types_skel = [
            PreparedRelationType(
                rel.name, rel.node_type_row, rel.node_type_column,
                rel.adjacency_matrix, rel.adjacency_matrix_backward,
                _empty_edge_list_tvt(), _empty_edge_list_tvt(),
                _empty_edge_list_tvt(), _empty_edge_list_tvt())
            for rel in fam.relation_types
        ]
        fam_skels.append(PreparedRelationFamily(
            fam.data, fam.name, fam.node_type_row, fam.node_type_column,
            fam.is_symmetric, fam.decoder_class, rel_types_skel))

    return BatchedData(data.node_types, fam_skels)


class DataBatcher(object):
    """Iterate over a :class:`PreparedData` in mini-batches of edges.

    For every relation type, every edge kind in ``_EDGE_TYPES`` and every
    split in ``_PART_TYPES``, the edge list is cut into consecutive slices of
    at most ``batch_size`` rows. Each slice is yielded as its own
    :class:`BatchedData` (see :func:`batched_data_skeleton`).

    Args:
        data: The prepared data to batch.
        batch_size: Maximum number of edges per batch; a positive ``int``.
        shuffle: If true, each edge list is permuted with
            :func:`torch.randperm` before slicing, and the complete set of
            batches is then yielded in an order shuffled by
            :func:`random.shuffle`.

    Raises:
        TypeError: If ``data`` is not a :class:`PreparedData` or
            ``batch_size`` is not an ``int`` (``bool`` is rejected).
        ValueError: If ``batch_size`` is less than 1.
    """

    def __init__(self, data: PreparedData, batch_size: int,
                 shuffle: bool = True) -> None:
        self._check_params(data, batch_size)
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[BatchedData]:
        """Yield batches, globally shuffled when ``self.shuffle`` is true."""
        yield from (self.shuffle_iter() if self.shuffle else self.iter_base())

    def iter_base(self) -> Iterator[BatchedData]:
        """Yield batches edge list by edge list, in traversal order.

        Traversal order is family, relation type, edge kind, split. Edges
        within a list are permuted first when ``self.shuffle`` is true.
        """
        for i, fam in enumerate(self.data.relation_families):
            for k, rel in enumerate(fam.relation_types):
                for edge_type in _EDGE_TYPES:
                    for part_type in _PART_TYPES:
                        edges = getattr(getattr(rel, edge_type), part_type)
                        if self.shuffle:
                            edges = edges[torch.randperm(len(edges))]
                        for m in range(0, len(edges), self.batch_size):
                            batched_data = batched_data_skeleton(self.data)
                            rel_skel = batched_data.relation_families[i] \
                                .relation_types[k]
                            setattr(getattr(rel_skel, edge_type), part_type,
                                    edges[m:m + self.batch_size])
                            yield batched_data

    def shuffle_iter(self) -> Iterator[BatchedData]:
        """Yield every batch from :meth:`iter_base` in random order.

        All batches are materialised in a list first so they can be shuffled
        globally; memory use therefore grows with the total number of batches.
        """
        res = list(self.iter_base())
        random.shuffle(res)
        yield from res

    @staticmethod
    def _check_params(data, batch_size):
        """Validate constructor arguments; see the class docstring."""
        if not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of PreparedData')
        # bool is a subclass of int, so it has to be excluded explicitly.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool):
            raise TypeError('batch_size must be an int')
        # Zero would only fail later inside range(); negative values would
        # silently produce no batches at all.
        if batch_size < 1:
            raise ValueError('batch_size must be a positive int')