|  |  | @@ -1,6 +1,7 @@ | 
		
	
		
			
			|  |  |  | from .data import Data | 
		
	
		
			
			|  |  |  | from .model import TrainingBatch | 
		
	
		
			
			|  |  |  | import torch | 
		
	
		
			
			|  |  |  | from functools import reduce | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def _shuffle(x: torch.Tensor) -> torch.Tensor: | 
		
	
	
		
			
				|  |  | @@ -8,6 +9,66 @@ def _shuffle(x: torch.Tensor) -> torch.Tensor: | 
		
	
		
			
			|  |  |  | return x[order] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def _same_data_org(pos_data: Data, neg_data: Data): | 
		
	
		
			
			|  |  |  | if len(pos_data.vertex_types) != len(neg_data.vertex_types): | 
		
	
		
			
			|  |  |  | return False | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | test = [ pos_data.vertex_types[i].name == neg_data.vertex_types[i].name \ | 
		
	
		
			
			|  |  |  | and pos_data.vertex_types[i].count == neg_data.vertex_types[i].count \ | 
		
	
		
			
			|  |  |  | for i in range(len(pos_data.vertex_types)) ] | 
		
	
		
			
			|  |  |  | if not all(test): | 
		
	
		
			
			|  |  |  | return False | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not set(pos_data.edge_types.keys()) == \ | 
		
	
		
			
			|  |  |  | set(neg_data.edge_types.keys()): | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | return False | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | test = [ pos_data.edge_types[i].name == \ | 
		
	
		
			
			|  |  |  | neg_data.edge_types[i].name \ | 
		
	
		
			
			|  |  |  | and pos_data.edge_types[i].vertex_type_row == \ | 
		
	
		
			
			|  |  |  | neg_data.edge_types[i].vertex_type_row \ | 
		
	
		
			
			|  |  |  | and pos_data.edge_types[i].vertex_type_column == \ | 
		
	
		
			
			|  |  |  | neg_data.edge_types[i].vertex_type_column \ | 
		
	
		
			
			|  |  |  | and len(pos_data.edge_types[i].adjacency_matrices) == \ | 
		
	
		
			
			|  |  |  | len(neg_data.edge_types[i].adjacency_matrices) \ | 
		
	
		
			
			|  |  |  | for i in pos_data.edge_types.keys() ] | 
		
	
		
			
			|  |  |  | if not all(test): | 
		
	
		
			
			|  |  |  | return False | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | test = [ [ len(pos_data.edge_types[i].adjacency_matrices[k].values()) == \ | 
		
	
		
			
			|  |  |  | len(neg_data.edge_types[i].adjacency_matrices[k].values()) \ | 
		
	
		
			
			|  |  |  | for k in range(len(pos_data.edge_types[i])) ] \ | 
		
	
		
			
			|  |  |  | for i in pos_data.edge_types.keys() ] | 
		
	
		
			
			|  |  |  | test = reduce(list.__add__, test) | 
		
	
		
			
			|  |  |  | if not all(test): | 
		
	
		
			
			|  |  |  | return False | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | return True | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class DualBatcher(object): | 
		
	
		
			
			|  |  |  | def __init__(self, pos_data: Data, neg_data: Data, | 
		
	
		
			
			|  |  |  | batch_size: int=512, shuffle: bool=True) -> None: | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not isinstance(pos_data, Data): | 
		
	
		
			
			|  |  |  | raise TypeError('pos_data must be an instance of Data') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not isinstance(neg_data, Data): | 
		
	
		
			
			|  |  |  | raise TypeError('neg_data must be an instance of Data') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not _same_data_org(pos_data, neg_data): | 
		
	
		
			
			|  |  |  | raise ValueError('pos_data and neg_data must have the same organization') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.pos_data = pos_data | 
		
	
		
			
			|  |  |  | self.neg_data = neg_data | 
		
	
		
			
			|  |  |  | self.batch_size = int(batch_size) | 
		
	
		
			
			|  |  |  | self.shuffle = bool(shuffle) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def __iter__(self): | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class Batcher(object): | 
		
	
		
			
			|  |  |  | def __init__(self, data: Data, batch_size: int=512, | 
		
	
		
			
			|  |  |  | shuffle: bool=True) -> None: | 
		
	
	
		
			
				|  |  | 
 |