From 490b4f9281e3e4af34167d665d076ab8f4396501 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 21 Aug 2020 17:24:07 +0200 Subject: [PATCH] New robust implementation of fixed_unigram_candidate_sampler() seems to be working. --- src/triacontagon/cumcount.py | 2 +- src/triacontagon/sampling.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/triacontagon/cumcount.py b/src/triacontagon/cumcount.py index 164c784..33169f9 100644 --- a/src/triacontagon/cumcount.py +++ b/src/triacontagon/cumcount.py @@ -9,7 +9,7 @@ def dfill(a): torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1, torch.tensor([n]) ]) - print('b:',b) + # print('b:',b) res = torch.arange(n)[b[:-1]] res = torch.repeat_interleave(res, b[1:] - b[:-1]) return res diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index d43afd0..3ae0ef3 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -77,14 +77,17 @@ def fixed_unigram_candidate_sampler( mask = mask | repeated updated = indices[~mask] - ofs = true_class_count[updated[:, 1]] + \ - cumcount(updated[:, 1]) - true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1) - true_class_count[updated[:, 1]] = ofs + 1 + if len(updated) > 0: + ofs = true_class_count[updated[:, 1]] + \ + cumcount(updated[:, 1]) + true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1) + true_class_count[updated[:, 1]] = ofs + 1 result[indices[:, 0]] = candidates.transpose(0, 1) indices = indices[mask] + return result + def fixed_unigram_candidate_sampler_slow( true_classes: torch.Tensor,