|
|
@@ -21,7 +21,7 @@ from itertools import product, \ |
|
|
|
from functools import reduce
|
|
|
|
|
|
|
|
|
|
|
|
def fixed_unigram_candidate_sampler(
|
|
|
|
def fixed_unigram_candidate_sampler_new(
|
|
|
|
true_classes: torch.Tensor,
|
|
|
|
num_repeats: torch.Tensor,
|
|
|
|
unigrams: torch.Tensor,
|
|
|
@@ -57,6 +57,8 @@ def fixed_unigram_candidate_sampler( |
|
|
|
result = torch.zeros(len(indices), dtype=torch.long)
|
|
|
|
|
|
|
|
while len(indices) > 0:
|
|
|
|
print(len(indices))
|
|
|
|
|
|
|
|
candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
|
|
|
|
candidates = torch.tensor(list(candidates)).view(-1, 1)
|
|
|
|
|
|
|
@@ -73,6 +75,7 @@ def fixed_unigram_candidate_sampler( |
|
|
|
can_cum = cumcount(candidates[:, 0])
|
|
|
|
ind_cum = cumcount(indices[:, 1])
|
|
|
|
repeated = (can_cum > 0) & (ind_cum > 0)
|
|
|
|
# TODO: this is wrong, still requires work
|
|
|
|
|
|
|
|
mask = mask | repeated
|
|
|
|
|
|
|
@@ -138,7 +141,7 @@ def fixed_unigram_candidate_sampler_slow( |
|
|
|
return torch.tensor(res)
|
|
|
|
|
|
|
|
|
|
|
|
def fixed_unigram_candidate_sampler_old(
|
|
|
|
def fixed_unigram_candidate_sampler(
|
|
|
|
true_classes: torch.Tensor,
|
|
|
|
num_repeats: torch.Tensor,
|
|
|
|
unigrams: torch.Tensor,
|
|
|
|