|  |  | @@ -2,6 +2,8 @@ from icosagon.trainprep import PreparedData, \ | 
		
	
		
			
			|  |  |  | PreparedRelationFamily, \ | 
		
	
		
			
			|  |  |  | PreparedRelationType, \ | 
		
	
		
			
			|  |  |  | _empty_edge_list_tvt | 
		
	
		
			
			|  |  |  | import torch | 
		
	
		
			
			|  |  |  | import random | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class BatchedData(PreparedData): | 
		
	
	
		
			
				|  |  | @@ -31,11 +33,13 @@ def batched_data_skeleton(data: PreparedData) -> BatchedData: | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class DataBatcher(object): | 
		
	
		
			
			|  |  |  | def __init__(self, data: PreparedData, batch_size: int) -> None: | 
		
	
		
			
			|  |  |  | def __init__(self, data: PreparedData, batch_size: int, | 
		
	
		
			
			|  |  |  | shuffle: bool = True) -> None: | 
		
	
		
			
			|  |  |  | self._check_params(data, batch_size) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.data = data | 
		
	
		
			
			|  |  |  | self.batch_size = batch_size | 
		
	
		
			
			|  |  |  | self.shuffle = shuffle | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | # def batched_data_iter(self, fam_idx: int, rel_idx: int, | 
		
	
		
			
			|  |  |  | #     part_type: str) -> BatchedData: | 
		
	
	
		
			
				|  |  | @@ -71,17 +75,34 @@ class DataBatcher(object): | 
		
	
		
			
			|  |  |  | #         yield batched_data | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def __iter__(self) -> BatchedData: | 
		
	
		
			
			|  |  |  | gen = self.shuffle_iter() \ | 
		
	
		
			
			|  |  |  | if self.shuffle \ | 
		
	
		
			
			|  |  |  | else self.iter_base() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for batched_data in gen: | 
		
	
		
			
			|  |  |  | yield batched_data | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def iter_base(self) -> BatchedData: | 
		
	
		
			
			|  |  |  | for i, fam in enumerate(self.data.relation_families): | 
		
	
		
			
			|  |  |  | for k, rel in enumerate(fam.relation_types): | 
		
	
		
			
			|  |  |  | for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']: | 
		
	
		
			
			|  |  |  | for part_type in ['train', 'val', 'test']: | 
		
	
		
			
			|  |  |  | edges = getattr(getattr(rel, edge_type), part_type) | 
		
	
		
			
			|  |  |  | if self.shuffle: | 
		
	
		
			
			|  |  |  | perm = torch.randperm(len(edges)) | 
		
	
		
			
			|  |  |  | edges = edges[perm] | 
		
	
		
			
			|  |  |  | for m in range(0, len(edges), self.batch_size): | 
		
	
		
			
			|  |  |  | batched_data = batched_data_skeleton(self.data) | 
		
	
		
			
			|  |  |  | setattr(getattr(batched_data.relation_families[i].relation_types[k], | 
		
	
		
			
			|  |  |  | edge_type), part_type, edges[m : m + self.batch_size]) | 
		
	
		
			
			|  |  |  | yield batched_data | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def shuffle_iter(self) -> BatchedData: | 
		
	
		
			
			|  |  |  | res = list(self.iter_base()) | 
		
	
		
			
			|  |  |  | random.shuffle(res) | 
		
	
		
			
			|  |  |  | for batched_data in res: | 
		
	
		
			
			|  |  |  | yield batched_data | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | @staticmethod | 
		
	
		
			
			|  |  |  | def _check_params(data, batch_size): | 
		
	
		
			
			|  |  |  | if not isinstance(data, PreparedData): | 
		
	
	
		
			
				|  |  | 
 |