IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
您最多可以选择25个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过35个字符。

39 行
1.4KB

  1. import numpy as np
  2. import torch
  3. import torch.utils.data
  4. from typing import List, \
  5. Union
  6. def fixed_unigram_candidate_sampler(
  7. true_classes: Union[np.array, torch.Tensor],
  8. num_true: int,
  9. num_samples: int,
  10. range_max: int,
  11. unigrams: List[Union[int, float]],
  12. distortion: float = 1.):
  13. if isinstance(true_classes, torch.Tensor):
  14. true_classes = true_classes.detach().cpu().numpy()
  15. if true_classes.shape != (num_samples, num_true):
  16. raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
  17. unigrams = np.array(unigrams)
  18. if distortion != 1.:
  19. unigrams = unigrams.astype(np.float64) ** distortion
  20. # print('unigrams:', unigrams)
  21. indices = np.arange(num_samples)
  22. result = np.zeros(num_samples, dtype=np.int64)
  23. while len(indices) > 0:
  24. # print('len(indices):', len(indices))
  25. sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
  26. candidates = np.array(list(sampler))
  27. candidates = np.reshape(candidates, (len(indices), 1))
  28. # print('candidates:', candidates)
  29. # print('true_classes:', true_classes[indices, :])
  30. result[indices] = candidates.T
  31. mask = (candidates == true_classes[indices, :])
  32. mask = mask.sum(1).astype(np.bool)
  33. # print('mask:', mask)
  34. indices = indices[mask]
  35. return result