#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import numpy as np
import torch
import torch.utils.data
from typing import List, Union


def fixed_unigram_candidate_sampler(
        true_classes: Union[np.ndarray, torch.Tensor],
        unigrams: Union[List[Union[int, float]], np.ndarray, torch.Tensor],
        distortion: float = 1.) -> torch.Tensor:
    """Draw one negative class per row, avoiding that row's true classes.

    Class ids are sampled with replacement, with probability proportional to
    ``unigrams ** distortion``. A draw is rejected and repeated for every row
    in which the drawn id equals any entry of that row of ``true_classes``.

    Args:
        true_classes: 2D array-like of shape ``(num_samples, num_true)``
            holding the class ids that must not be returned for each row.
            Tensors are detached and moved to the CPU first.
        unigrams: 1D sequence of non-negative weights, one per class id;
            the index into this sequence is the class id that gets sampled.
        distortion: Exponent applied to the weights before sampling. The
            weights are used unchanged when it equals ``1.``.

    Returns:
        A 1D ``torch.int64`` tensor of length ``num_samples`` with the
        sampled class id for each row.

    Raises:
        ValueError: If ``true_classes`` is not two-dimensional.

    Note:
        Rejection sampling only terminates if every row leaves at least one
        class with a positive weight outside of its true classes; otherwise
        the loop never finishes.
    """
    if isinstance(true_classes, torch.Tensor):
        true_classes = true_classes.detach().cpu().numpy()
    else:
        # Also accept plain nested sequences, not only ndarrays.
        true_classes = np.asarray(true_classes)
    if isinstance(unigrams, torch.Tensor):
        unigrams = unigrams.detach().cpu().numpy()

    if true_classes.ndim != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    num_samples = true_classes.shape[0]
    unigrams = np.asarray(unigrams)
    if distortion != 1.:
        unigrams = unigrams.astype(np.float64) ** distortion

    # Convert the weights once; they do not change between resampling rounds.
    weights = torch.as_tensor(unigrams, dtype=torch.double)

    result = np.zeros(num_samples, dtype=np.int64)
    pending = np.arange(num_samples)  # rows that still need a valid candidate
    while len(pending) > 0:
        candidates = torch.multinomial(
            weights, len(pending), replacement=True).numpy()
        result[pending] = candidates
        # A row must be redrawn when its candidate matches any true class.
        # `.any(axis=1)` replaces `.sum(1).astype(np.bool)`: the `np.bool`
        # alias was removed from NumPy and raised AttributeError.
        collides = (candidates[:, np.newaxis] == true_classes[pending, :]).any(axis=1)
        pending = pending[collides]

    return torch.tensor(result)