|  |  | @@ -21,7 +21,7 @@ from itertools import product, \ | 
		
	
		
			
			|  |  |  | from functools import reduce | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def fixed_unigram_candidate_sampler_new( | 
		
	
		
			
			|  |  |  | def fixed_unigram_candidate_sampler( | 
		
	
		
			
			|  |  |  | true_classes: torch.Tensor, | 
		
	
		
			
			|  |  |  | num_repeats: torch.Tensor, | 
		
	
		
			
			|  |  |  | unigrams: torch.Tensor, | 
		
	
	
		
			
				|  |  | @@ -72,9 +72,10 @@ def fixed_unigram_candidate_sampler_new( | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | can_cum = cumcount(candidates[:, 0]) | 
		
	
		
			
			|  |  |  | # can_cum = cumcount(candidates[:, 0]) | 
		
	
		
			
			|  |  |  | can_diff = torch.cat([ torch.tensor([1]), candidates[1:, 0] - candidates[:-1, 0] ]) | 
		
	
		
			
			|  |  |  | ind_cum = cumcount(indices[:, 1]) | 
		
	
		
			
			|  |  |  | repeated = (can_cum > 0) & (ind_cum > 0) | 
		
	
		
			
			|  |  |  | repeated = (can_diff == 0) & (ind_cum > 0) | 
		
	
		
			
			|  |  |  | # TODO: this is wrong, still requires work | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | mask = mask | repeated | 
		
	
	
		
			
				|  |  | @@ -141,7 +142,7 @@ def fixed_unigram_candidate_sampler_slow( | 
		
	
		
			
			|  |  |  | return torch.tensor(res) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def fixed_unigram_candidate_sampler( | 
		
	
		
			
			|  |  |  | def fixed_unigram_candidate_sampler_old( | 
		
	
		
			
			|  |  |  | true_classes: torch.Tensor, | 
		
	
		
			
			|  |  |  | num_repeats: torch.Tensor, | 
		
	
		
			
			|  |  |  | unigrams: torch.Tensor, | 
		
	
	
		
			
				|  |  | 
 |