|
|
@@ -77,14 +77,17 @@ def fixed_unigram_candidate_sampler( |
|
|
|
mask = mask | repeated
|
|
|
|
|
|
|
|
updated = indices[~mask]
|
|
|
|
ofs = true_class_count[updated[:, 1]] + \
|
|
|
|
cumcount(updated[:, 1])
|
|
|
|
true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
|
|
|
|
true_class_count[updated[:, 1]] = ofs + 1
|
|
|
|
if len(updated) > 0:
|
|
|
|
ofs = true_class_count[updated[:, 1]] + \
|
|
|
|
cumcount(updated[:, 1])
|
|
|
|
true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
|
|
|
|
true_class_count[updated[:, 1]] = ofs + 1
|
|
|
|
|
|
|
|
result[indices[:, 0]] = candidates.transpose(0, 1)
|
|
|
|
indices = indices[mask]
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
def fixed_unigram_candidate_sampler_slow(
|
|
|
|
true_classes: torch.Tensor,
|
|
|
|