diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index 3ae0ef3..b88f90d 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -21,7 +21,7 @@ from itertools import product, \ from functools import reduce -def fixed_unigram_candidate_sampler( +def fixed_unigram_candidate_sampler_new( true_classes: torch.Tensor, num_repeats: torch.Tensor, unigrams: torch.Tensor, @@ -57,6 +57,8 @@ def fixed_unigram_candidate_sampler( result = torch.zeros(len(indices), dtype=torch.long) while len(indices) > 0: + print(len(indices)) + candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices)) candidates = torch.tensor(list(candidates)).view(-1, 1) @@ -73,6 +75,7 @@ def fixed_unigram_candidate_sampler( can_cum = cumcount(candidates[:, 0]) ind_cum = cumcount(indices[:, 1]) repeated = (can_cum > 0) & (ind_cum > 0) + # TODO: this is wrong, still requires work mask = mask | repeated @@ -138,7 +141,7 @@ def fixed_unigram_candidate_sampler_slow( return torch.tensor(res) -def fixed_unigram_candidate_sampler_old( +def fixed_unigram_candidate_sampler( true_classes: torch.Tensor, num_repeats: torch.Tensor, unigrams: torch.Tensor,