IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import numpy as np
  6. import torch
  7. import torch.utils.data
  8. from typing import List, \
  9. Union
  10. def fixed_unigram_candidate_sampler(
  11. true_classes: Union[np.array, torch.Tensor],
  12. unigrams: List[Union[int, float]],
  13. distortion: float = 1.):
  14. if isinstance(true_classes, torch.Tensor):
  15. true_classes = true_classes.detach().cpu().numpy()
  16. if len(true_classes.shape) != 2:
  17. raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
  18. num_samples = true_classes.shape[0]
  19. unigrams = np.array(unigrams)
  20. if distortion != 1.:
  21. unigrams = unigrams.astype(np.float64) ** distortion
  22. # print('unigrams:', unigrams)
  23. indices = np.arange(num_samples)
  24. result = np.zeros(num_samples, dtype=np.int64)
  25. while len(indices) > 0:
  26. # print('len(indices):', len(indices))
  27. sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
  28. candidates = np.array(list(sampler))
  29. candidates = np.reshape(candidates, (len(indices), 1))
  30. # print('candidates:', candidates)
  31. # print('true_classes:', true_classes[indices, :])
  32. result[indices] = candidates.T
  33. mask = (candidates == true_classes[indices, :])
  34. mask = mask.sum(1).astype(np.bool)
  35. # print('mask:', mask)
  36. indices = indices[mask]
  37. return torch.tensor(result)