@@ -14,6 +14,60 @@ from .data import Data, \
    EdgeType
from .cumcount import cumcount
import time
import multiprocessing
import multiprocessing.pool
from itertools import product, \
    repeat
from functools import reduce


def fixed_unigram_candidate_sampler_slow(
        true_classes: torch.Tensor,
        num_repeats: torch.Tensor,
        unigrams: torch.Tensor,
        distortion: float = 1.) -> torch.Tensor:

    assert isinstance(true_classes, torch.Tensor)
    assert isinstance(num_repeats, torch.Tensor)
    assert isinstance(unigrams, torch.Tensor)
    distortion = float(distortion)

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    if torch.any((unigrams > 0).sum() - \
            (true_classes >= 0).sum(dim=1) < \
            num_repeats):
        raise ValueError('Not enough classes to choose from')

    res = []

    if distortion != 1.:
        unigrams = unigrams.to(torch.float64)
        unigrams = unigrams ** distortion

    def fun(i):
        if i and i % 100 == 0:
            print(i)
        if num_repeats[i] == 0:
            return []
        pos = torch.flatten(true_classes[i, :])
        pos = pos[pos >= 0]
        w = unigrams.clone().detach()
        w[pos] = 0
        sampler = torch.utils.data.WeightedRandomSampler(w,
            num_repeats[i].item(), replacement=False)
        res = list(sampler)
        return res

    with multiprocessing.pool.ThreadPool() as p:
        res = p.map(fun, range(len(num_repeats)))
    res = reduce(list.__add__, res, [])

    return torch.tensor(res)


def fixed_unigram_candidate_sampler(
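For context on the hunk above: fixed_unigram_candidate_sampler_slow draws, for each row of true_classes, num_repeats[i] negative classes weighted by unigrams (optionally raised to the distortion power), zeroes out that row's true classes so they cannot be drawn, and concatenates the per-row draws into one flat tensor. The sketch below is a usage illustration only; the toy shapes and values are assumptions, not taken from this PR, and it presumes torch plus the function defined above are importable.

# Usage sketch only (assumed toy data; not part of this PR).
import torch

true_classes = torch.tensor([[0, 1],
                             [2, -1],      # -1 pads rows with fewer true classes
                             [3, 4]])      # shape (num_samples, num_true)
num_repeats = torch.tensor([2, 2, 2])      # negatives to draw per row
unigrams = torch.ones(6)                   # per-class sampling weights

negatives = fixed_unigram_candidate_sampler_slow(true_classes, num_repeats, unigrams)
print(negatives.shape)                     # torch.Size([6]): per-row draws concatenated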
@@ -61,6 +115,12 @@ def fixed_unigram_candidate_sampler(
        print('result:', result)
        mask = (candidates == true_classes[indices[:, 1], :])
        mask = mask.sum(1).to(torch.bool)
        # append_true_classes = torch.full(( len(true_classes), ), -1)
        # append_true_classes[~mask] = torch.flatten(candidates)[~mask]
        # true_classes = torch.cat([
        #     append_true_classes.view(-1, 1),
        #     true_classes
        # ], dim=1)
        print('mask:', mask)
        indices = indices[mask]
        # result[indices] = 0
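A note on the collision mask introduced in this hunk: true_classes[indices[:, 1], :] gathers the true-class rows for the entries still being processed, the broadcast == flags candidates that collide with any of those true classes, and indices = indices[mask] keeps only the colliding entries for further handling. The toy tensors below are assumptions chosen to show the broadcasting; they do not come from the surrounding function.

# Illustration only: toy shapes assumed, not taken from this PR.
import torch

true_classes = torch.tensor([[1, 2],
                             [3, 4]])      # (num_samples, num_true)
candidates = torch.tensor([[2],
                           [5]])           # one drawn candidate per pending entry
indices = torch.tensor([[0, 0],
                        [1, 1]])           # column 1 selects the sample row

mask = (candidates == true_classes[indices[:, 1], :])  # broadcast collision test
mask = mask.sum(1).to(torch.bool)                      # True where a candidate equals any true class
print(mask)   # tensor([ True, False]) -> only the first entry collides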