diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index 29ac224..8beac99 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -28,6 +28,11 @@ def fixed_unigram_candidate_sampler( if len(num_repeats.shape) != 1: raise ValueError('num_repeats must be 1D') + if torch.any(len(unigrams) - \ + (true_classes >= 0).sum(dim=1) < \ + num_repeats): + raise ValueError('Not enough classes to choose from') + num_rows = true_classes.shape[0] print('true_classes.shape:', true_classes.shape) # unigrams = np.array(unigrams)