|
123456789101112131415161718192021222324252627282930313233343536 |
- import numpy as np
- import torch
- import torch.utils.data
- from typing import List, \
- Union
-
-
def fixed_unigram_candidate_sampler(
        true_classes: Union[np.ndarray, torch.Tensor],
        num_samples: int,
        unigrams: List[Union[int, float]],
        distortion: float = 1.) -> np.ndarray:
    """Draw one candidate class per row, avoiding that row's true classes.

    Each candidate is drawn from a categorical distribution whose weights
    are ``unigrams ** distortion``. A row whose draw equals any of its own
    true classes is redrawn (rejection sampling) until every row holds a
    candidate outside its true set.

    Args:
        true_classes: 2D array/tensor of shape ``(num_samples, num_true)``
            listing, per row, the class ids that must not be returned.
            Tensors are detached and copied to CPU first.
        num_samples: Number of rows; must equal ``true_classes.shape[0]``.
        unigrams: Non-negative per-class weights; class ``i`` is drawn with
            weight ``unigrams[i] ** distortion``.
        distortion: Exponent applied to every weight before sampling
            (``1.`` leaves the weights unchanged).

    Returns:
        ``int64`` array of shape ``(num_samples,)``; entry ``i`` is the
        candidate drawn for row ``i``.

    Raises:
        ValueError: If ``true_classes`` is not 2D with ``num_samples`` rows,
            or if some row's true classes include every class with positive
            weight (no valid candidate exists, so sampling could never end).
    """
    if isinstance(true_classes, torch.Tensor):
        true_classes = true_classes.detach().cpu().numpy()
    true_classes = np.asarray(true_classes)
    # Check the rank as well as the row count: a 1D input used to slip
    # through and fail later with an opaque IndexError.
    if true_classes.ndim != 2 or true_classes.shape[0] != num_samples:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    weights = np.asarray(unigrams, dtype=np.float64)
    if distortion != 1.:
        weights = weights ** distortion

    stuck = _rows_without_candidates(true_classes, weights)
    if stuck.size:
        raise ValueError(
            f'true_classes rows {stuck[:10].tolist()} cover every class with '
            f'positive unigram weight, so no candidate can be sampled for them')

    # Convert the weights once; WeightedRandomSampler makes this same
    # torch.multinomial call but re-converted the weights on every round.
    weight_tensor = torch.as_tensor(weights, dtype=torch.double)
    result = np.zeros(num_samples, dtype=np.int64)
    pending = np.arange(num_samples)
    while pending.size:
        draws = torch.multinomial(weight_tensor, pending.size, replacement=True).numpy()
        result[pending] = draws
        # A draw is rejected when it matches any true class of its row.
        # (.any() replaces .sum().astype(np.bool); np.bool was removed in NumPy 1.24.)
        rejected = (draws[:, None] == true_classes[pending]).any(axis=1)
        pending = pending[rejected]
    return result


def _rows_without_candidates(true_classes: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Return indices of rows whose true classes include every positive-weight class.

    Rejection sampling can never finish for such rows. Degenerate weight
    vectors (no positive entry) are not reported here; torch.multinomial
    rejects those itself when sampling is attempted.
    """
    num_classes = weights.shape[0]
    positive = weights > 0
    n_positive = int(np.count_nonzero(positive))
    if n_positive == 0:
        return np.empty(0, dtype=np.int64)
    ids = np.sort(true_classes, axis=1)
    # Count each class at most once per row (duplicates are adjacent after sorting)...
    first = np.ones(ids.shape, dtype=bool)
    first[:, 1:] = ids[:, 1:] != ids[:, :-1]
    # ...and only ids a draw could ever equal: integral values in [0, num_classes).
    drawable = first & (ids >= 0) & (ids < num_classes) & (ids == np.trunc(ids))
    lookup = np.where(drawable, ids, 0).astype(np.int64)
    blocked = np.count_nonzero(drawable & positive[lookup], axis=1)
    return np.flatnonzero(blocked >= n_positive)
|