From 356f3af3c38ef87ef40c59ae96f2e52def5e7f2d Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 21 Aug 2020 18:51:27 +0200 Subject: [PATCH] With this weird little trick the new fixed_unigram_candidate_sampler() finally seems to be fully robust and to actually work. --- src/triacontagon/sampling.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index b88f90d..73d7cf2 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -21,7 +21,7 @@ from itertools import product, \ from functools import reduce -def fixed_unigram_candidate_sampler_new( +def fixed_unigram_candidate_sampler( true_classes: torch.Tensor, num_repeats: torch.Tensor, unigrams: torch.Tensor, @@ -72,9 +72,10 @@ def fixed_unigram_candidate_sampler_new( mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool) - can_cum = cumcount(candidates[:, 0]) + # can_cum = cumcount(candidates[:, 0]) + can_diff = torch.cat([ torch.tensor([1]), candidates[1:, 0] - candidates[:-1, 0] ]) ind_cum = cumcount(indices[:, 1]) - repeated = (can_cum > 0) & (ind_cum > 0) + repeated = (can_diff == 0) & (ind_cum > 0) # TODO: this is wrong, still requires work mask = mask | repeated @@ -141,7 +142,7 @@ def fixed_unigram_candidate_sampler_slow( return torch.tensor(res) -def fixed_unigram_candidate_sampler( +def fixed_unigram_candidate_sampler_old( true_classes: torch.Tensor, num_repeats: torch.Tensor, unigrams: torch.Tensor,