diff --git a/src/triacontagon/cumcount.py b/src/triacontagon/cumcount.py index 164c784..33169f9 100644 --- a/src/triacontagon/cumcount.py +++ b/src/triacontagon/cumcount.py @@ -9,7 +9,7 @@ def dfill(a): torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1, torch.tensor([n]) ]) - print('b:',b) + # print('b:',b) res = torch.arange(n)[b[:-1]] res = torch.repeat_interleave(res, b[1:] - b[:-1]) return res diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index d43afd0..3ae0ef3 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -77,14 +77,17 @@ def fixed_unigram_candidate_sampler( mask = mask | repeated updated = indices[~mask] - ofs = true_class_count[updated[:, 1]] + \ - cumcount(updated[:, 1]) - true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1) - true_class_count[updated[:, 1]] = ofs + 1 + if len(updated) > 0: + ofs = true_class_count[updated[:, 1]] + \ + cumcount(updated[:, 1]) + true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1) + true_class_count[updated[:, 1]] = ofs + 1 result[indices[:, 0]] = candidates.transpose(0, 1) indices = indices[mask] + return result + def fixed_unigram_candidate_sampler_slow( true_classes: torch.Tensor,