|  |  | @@ -0,0 +1,91 @@ | 
		
	
		
			
			|  |  |  | from icosagon.trainprep import PreparedData, \ | 
		
	
		
			
			|  |  |  | PreparedRelationFamily, \ | 
		
	
		
			
			|  |  |  | PreparedRelationType, \ | 
		
	
		
			
			|  |  |  | _empty_edge_list_tvt | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class BatchedData(PreparedData): | 
		
	
		
			
			|  |  |  | def __init__(self, *args, **kwargs): | 
		
	
		
			
			|  |  |  | super().__init__(*args, **kwargs) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def batched_data_skeleton(data: PreparedData) -> BatchedData: | 
		
	
		
			
			|  |  |  | if not isinstance(data, PreparedData): | 
		
	
		
			
			|  |  |  | raise TypeError('data must be an instance of PreparedData') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | fam_skels = [] | 
		
	
		
			
			|  |  |  | for fam in data.relation_families: | 
		
	
		
			
			|  |  |  | rel_types_skel = [] | 
		
	
		
			
			|  |  |  | for rel in fam.relation_types: | 
		
	
		
			
			|  |  |  | rel_skel = PreparedRelationType(rel.name, | 
		
	
		
			
			|  |  |  | rel.node_type_row, rel.node_type_column, | 
		
	
		
			
			|  |  |  | rel.adjacency_matrix, rel.adjacency_matrix_backward, | 
		
	
		
			
			|  |  |  | _empty_edge_list_tvt(), _empty_edge_list_tvt(), | 
		
	
		
			
			|  |  |  | _empty_edge_list_tvt(), _empty_edge_list_tvt()) | 
		
	
		
			
			|  |  |  | rel_types_skel.append(rel_skel) | 
		
	
		
			
			|  |  |  | fam_skels.append(PreparedRelationFamily(fam.data, fam.name, | 
		
	
		
			
			|  |  |  | fam.node_type_row, fam.node_type_column, | 
		
	
		
			
			|  |  |  | fam.is_symmetric, fam.decoder_class, | 
		
	
		
			
			|  |  |  | rel_types_skel)) | 
		
	
		
			
			|  |  |  | return BatchedData(data.node_types, fam_skels) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class DataBatcher(object): | 
		
	
		
			
			|  |  |  | def __init__(self, data: PreparedData, batch_size: int) -> None: | 
		
	
		
			
			|  |  |  | self._check_params(data, batch_size) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.data = data | 
		
	
		
			
			|  |  |  | self.batch_size = batch_size | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | # def batched_data_iter(self, fam_idx: int, rel_idx: int, | 
		
	
		
			
			|  |  |  | #     part_type: str) -> BatchedData: | 
		
	
		
			
			|  |  |  | # | 
		
	
		
			
			|  |  |  | #     rel = self.data.relation_families[fam_idx].relation_types[rel_idx] | 
		
	
		
			
			|  |  |  | # | 
		
	
		
			
			|  |  |  | #     edges = getattr(rel.edges_pos, part_type) | 
		
	
		
			
			|  |  |  | #     for m in range(0, len(edges), self.batch_size): | 
		
	
		
			
			|  |  |  | #         batched_data = batched_data_skeleton(self.data) | 
		
	
		
			
			|  |  |  | #         setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_pos, | 
		
	
		
			
			|  |  |  | #             part_type, edges[m : m + self.batch_size]) | 
		
	
		
			
			|  |  |  | #         yield batched_data | 
		
	
		
			
			|  |  |  | # | 
		
	
		
			
			|  |  |  | #     edges = getattr(rel.edges_neg, part_type) | 
		
	
		
			
			|  |  |  | #     for m in range(0, len(edges), self.batch_size): | 
		
	
		
			
			|  |  |  | #         batched_data = batched_data_skeleton(self.data) | 
		
	
		
			
			|  |  |  | #         setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_neg, | 
		
	
		
			
			|  |  |  | #             part_type, edges[m : m + self.batch_size]) | 
		
	
		
			
			|  |  |  | #         yield batched_data | 
		
	
		
			
			|  |  |  | # | 
		
	
		
			
			|  |  |  | #     edges = getattr(rel.edges_pos_back, part_type) | 
		
	
		
			
			|  |  |  | #     for m in range(0, len(edges), self.batch_size): | 
		
	
		
			
			|  |  |  | #         batched_data = batched_data_skeleton(self.data) | 
		
	
		
			
			|  |  |  | #         setattr(batched_data.relation_families[i].relation_types[k].edges_pos_back, | 
		
	
		
			
			|  |  |  | #             part_type, edges[m : m + self.batch_size]) | 
		
	
		
			
			|  |  |  | #         yield batched_data | 
		
	
		
			
			|  |  |  | # | 
		
	
		
			
			|  |  |  | #     edges = getattr(rel.edges_neg_back, part_type) | 
		
	
		
			
			|  |  |  | #     for m in range(0, len(), self.batch_size): | 
		
	
		
			
			|  |  |  | #         batched_data = batched_data_skeleton(self.data) | 
		
	
		
			
			|  |  |  | #         setattr(batched_data.relation_families[i].relation_types[k].edges_neg_back, | 
		
	
		
			
			|  |  |  | #             edges[m : m + self.batch_size]) | 
		
	
		
			
			|  |  |  | #         yield batched_data | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def __iter__(self) -> BatchedData: | 
		
	
		
			
			|  |  |  | for i, fam in enumerate(self.data.relation_families): | 
		
	
		
			
			|  |  |  | for k, rel in enumerate(fam.relation_types): | 
		
	
		
			
			|  |  |  | for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']: | 
		
	
		
			
			|  |  |  | for part_type in ['train', 'val', 'test']: | 
		
	
		
			
			|  |  |  | edges = getattr(getattr(rel, edge_type), part_type) | 
		
	
		
			
			|  |  |  | for m in range(0, len(edges), self.batch_size): | 
		
	
		
			
			|  |  |  | batched_data = batched_data_skeleton(self.data) | 
		
	
		
			
			|  |  |  | setattr(getattr(batched_data.relation_families[i].relation_types[k], | 
		
	
		
			
			|  |  |  | edge_type), part_type, edges[m : m + self.batch_size]) | 
		
	
		
			
			|  |  |  | yield batched_data | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | @staticmethod | 
		
	
		
			
			|  |  |  | def _check_params(data, batch_size): | 
		
	
		
			
			|  |  |  | if not isinstance(data, PreparedData): | 
		
	
		
			
			|  |  |  | raise TypeError('data must be an instance of PreparedData') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not isinstance(batch_size, int): | 
		
	
		
			
			|  |  |  | raise TypeError('batch_size must be an int') |