|
|
@@ -21,7 +21,7 @@ from itertools import product, \ |
|
|
|
from functools import reduce
|
|
|
|
|
|
|
|
|
|
|
|
def fixed_unigram_candidate_sampler_new(
|
|
|
|
def fixed_unigram_candidate_sampler(
|
|
|
|
true_classes: torch.Tensor,
|
|
|
|
num_repeats: torch.Tensor,
|
|
|
|
unigrams: torch.Tensor,
|
|
|
@@ -72,9 +72,10 @@ def fixed_unigram_candidate_sampler_new( |
|
|
|
|
|
|
|
mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool)
|
|
|
|
|
|
|
|
can_cum = cumcount(candidates[:, 0])
|
|
|
|
# can_cum = cumcount(candidates[:, 0])
|
|
|
|
can_diff = torch.cat([ torch.tensor([1]), candidates[1:, 0] - candidates[:-1, 0] ])
|
|
|
|
ind_cum = cumcount(indices[:, 1])
|
|
|
|
repeated = (can_cum > 0) & (ind_cum > 0)
|
|
|
|
repeated = (can_diff == 0) & (ind_cum > 0)
|
|
|
|
# TODO: this is wrong, still requires work
|
|
|
|
|
|
|
|
mask = mask | repeated
|
|
|
@@ -141,7 +142,7 @@ def fixed_unigram_candidate_sampler_slow( |
|
|
|
return torch.tensor(res)
|
|
|
|
|
|
|
|
|
|
|
|
def fixed_unigram_candidate_sampler(
|
|
|
|
def fixed_unigram_candidate_sampler_old(
|
|
|
|
true_classes: torch.Tensor,
|
|
|
|
num_repeats: torch.Tensor,
|
|
|
|
unigrams: torch.Tensor,
|
|
|
|