| @@ -21,7 +21,7 @@ from itertools import product, \ | |||||
| from functools import reduce | from functools import reduce | ||||
| def fixed_unigram_candidate_sampler( | |||||
| def fixed_unigram_candidate_sampler_new( | |||||
| true_classes: torch.Tensor, | true_classes: torch.Tensor, | ||||
| num_repeats: torch.Tensor, | num_repeats: torch.Tensor, | ||||
| unigrams: torch.Tensor, | unigrams: torch.Tensor, | ||||
| @@ -57,6 +57,8 @@ def fixed_unigram_candidate_sampler( | |||||
| result = torch.zeros(len(indices), dtype=torch.long) | result = torch.zeros(len(indices), dtype=torch.long) | ||||
| while len(indices) > 0: | while len(indices) > 0: | ||||
| print(len(indices)) | |||||
| candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices)) | candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices)) | ||||
| candidates = torch.tensor(list(candidates)).view(-1, 1) | candidates = torch.tensor(list(candidates)).view(-1, 1) | ||||
| @@ -73,6 +75,7 @@ def fixed_unigram_candidate_sampler( | |||||
| can_cum = cumcount(candidates[:, 0]) | can_cum = cumcount(candidates[:, 0]) | ||||
| ind_cum = cumcount(indices[:, 1]) | ind_cum = cumcount(indices[:, 1]) | ||||
| repeated = (can_cum > 0) & (ind_cum > 0) | repeated = (can_cum > 0) & (ind_cum > 0) | ||||
| # TODO: this is wrong, still requires work | |||||
| mask = mask | repeated | mask = mask | repeated | ||||
| @@ -138,7 +141,7 @@ def fixed_unigram_candidate_sampler_slow( | |||||
| return torch.tensor(res) | return torch.tensor(res) | ||||
| def fixed_unigram_candidate_sampler_old( | |||||
| def fixed_unigram_candidate_sampler( | |||||
| true_classes: torch.Tensor, | true_classes: torch.Tensor, | ||||
| num_repeats: torch.Tensor, | num_repeats: torch.Tensor, | ||||
| unigrams: torch.Tensor, | unigrams: torch.Tensor, | ||||