From 2ff358f7efe52c9c193fdab7b952837ee923a2f0 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 20 Aug 2020 22:06:40 +0200 Subject: [PATCH] Protect against endless loop. --- src/triacontagon/sampling.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index 29ac224..8beac99 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -28,6 +28,11 @@ def fixed_unigram_candidate_sampler( if len(num_repeats.shape) != 1: raise ValueError('num_repeats must be 1D') + if torch.any(len(unigrams) - \ + (true_classes >= 0).sum(dim=1) < \ + num_repeats): + raise ValueError('Not enough classes to choose from') + num_rows = true_classes.shape[0] print('true_classes.shape:', true_classes.shape) # unigrams = np.array(unigrams)