|
@@ -28,6 +28,11 @@ def fixed_unigram_candidate_sampler( |
|
|
if len(num_repeats.shape) != 1:
|
|
|
if len(num_repeats.shape) != 1:
|
|
|
raise ValueError('num_repeats must be 1D')
|
|
|
raise ValueError('num_repeats must be 1D')
|
|
|
|
|
|
|
|
|
|
|
|
if torch.any(len(unigrams) - \
|
|
|
|
|
|
(true_classes >= 0).sum(dim=1) < \
|
|
|
|
|
|
num_repeats):
|
|
|
|
|
|
raise ValueError('Not enough classes to choose from')
|
|
|
|
|
|
|
|
|
num_rows = true_classes.shape[0]
|
|
|
num_rows = true_classes.shape[0]
|
|
|
print('true_classes.shape:', true_classes.shape)
|
|
|
print('true_classes.shape:', true_classes.shape)
|
|
|
# unigrams = np.array(unigrams)
|
|
|
# unigrams = np.array(unigrams)
|
|
|