IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sampling.py 1.4KB

123456789101112131415161718192021222324252627282930313233343536
  1. import numpy as np
  2. import torch
  3. import torch.utils.data
  4. from typing import List, \
  5. Union
  6. def fixed_unigram_candidate_sampler(
  7. true_classes: Union[np.array, torch.Tensor],
  8. num_samples: int,
  9. unigrams: List[Union[int, float]],
  10. distortion: float = 1.):
  11. if isinstance(true_classes, torch.Tensor):
  12. true_classes = true_classes.detach().cpu().numpy()
  13. if true_classes.shape[0] != num_samples:
  14. raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
  15. unigrams = np.array(unigrams)
  16. if distortion != 1.:
  17. unigrams = unigrams.astype(np.float64) ** distortion
  18. # print('unigrams:', unigrams)
  19. indices = np.arange(num_samples)
  20. result = np.zeros(num_samples, dtype=np.int64)
  21. while len(indices) > 0:
  22. # print('len(indices):', len(indices))
  23. sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
  24. candidates = np.array(list(sampler))
  25. candidates = np.reshape(candidates, (len(indices), 1))
  26. # print('candidates:', candidates)
  27. # print('true_classes:', true_classes[indices, :])
  28. result[indices] = candidates.T
  29. mask = (candidates == true_classes[indices, :])
  30. mask = mask.sum(1).astype(np.bool)
  31. # print('mask:', mask)
  32. indices = indices[mask]
  33. return result