IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

142 行
3.8KB

  1. import tensorflow as tf
  2. import numpy as np
  3. from collections import defaultdict
  4. import torch
  5. import torch.utils.data
  6. from typing import List, \
  7. Union
  8. import decagon_pytorch.sampling
  9. def test_unigram_01():
  10. range_max = 7
  11. distortion = 0.75
  12. batch_size = 500
  13. unigrams = [ 1, 3, 2, 1, 2, 1, 3]
  14. num_true = 1
  15. true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
  16. for i in range(batch_size):
  17. true_classes[i, 0] = i % range_max
  18. true_classes = tf.convert_to_tensor(true_classes)
  19. neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
  20. true_classes=true_classes,
  21. num_true=num_true,
  22. num_sampled=batch_size,
  23. unique=False,
  24. range_max=range_max,
  25. distortion=distortion,
  26. unigrams=unigrams)
  27. assert neg_samples.shape == (batch_size,)
  28. for i in range(batch_size):
  29. assert neg_samples[i] != true_classes[i, 0]
  30. counts = defaultdict(int)
  31. with tf.Session() as sess:
  32. neg_samples = neg_samples.eval()
  33. for x in neg_samples:
  34. counts[x] += 1
  35. print('counts:', counts)
  36. assert counts[0] < counts[1] and \
  37. counts[0] < counts[2] and \
  38. counts[0] < counts[4] and \
  39. counts[0] < counts[6]
  40. assert counts[2] < counts[1] and \
  41. counts[0] < counts[6]
  42. assert counts[3] < counts[1] and \
  43. counts[3] < counts[2] and \
  44. counts[3] < counts[4] and \
  45. counts[3] < counts[6]
  46. assert counts[4] < counts[1] and \
  47. counts[4] < counts[6]
  48. assert counts[5] < counts[1] and \
  49. counts[5] < counts[2] and \
  50. counts[5] < counts[4] and \
  51. counts[5] < counts[6]
  52. def test_unigram_02():
  53. range_max = 7
  54. distortion = 0.75
  55. batch_size = 500
  56. unigrams = [ 1, 3, 2, 1, 2, 1, 3]
  57. num_true = 1
  58. true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
  59. for i in range(batch_size):
  60. true_classes[i, 0] = i % range_max
  61. true_classes = torch.tensor(true_classes)
  62. neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler(
  63. true_classes=true_classes,
  64. num_true=num_true,
  65. num_samples=batch_size,
  66. range_max=range_max,
  67. distortion=distortion,
  68. unigrams=unigrams)
  69. assert neg_samples.shape == (batch_size,)
  70. for i in range(batch_size):
  71. assert neg_samples[i] != true_classes[i, 0]
  72. counts = defaultdict(int)
  73. for x in neg_samples:
  74. counts[x] += 1
  75. print('counts:', counts)
  76. assert counts[0] < counts[1] and \
  77. counts[0] < counts[2] and \
  78. counts[0] < counts[4] and \
  79. counts[0] < counts[6]
  80. assert counts[2] < counts[1] and \
  81. counts[0] < counts[6]
  82. assert counts[3] < counts[1] and \
  83. counts[3] < counts[2] and \
  84. counts[3] < counts[4] and \
  85. counts[3] < counts[6]
  86. assert counts[4] < counts[1] and \
  87. counts[4] < counts[6]
  88. assert counts[5] < counts[1] and \
  89. counts[5] < counts[2] and \
  90. counts[5] < counts[4] and \
  91. counts[5] < counts[6]
  92. def test_unigram_03():
  93. range_max = 7
  94. distortion = 0.75
  95. batch_size = 500
  96. unigrams = [ 1, 3, 2, 1, 2, 1, 3]
  97. num_true = 1
  98. true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
  99. for i in range(batch_size):
  100. true_classes[i, 0] = i % range_max
  101. true_classes_tf = tf.convert_to_tensor(true_classes)
  102. neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
  103. true_classes=true_classes_tf,
  104. num_true=num_true,
  105. num_sampled=batch_size,
  106. unique=False,
  107. range_max=range_max,
  108. distortion=distortion,
  109. unigrams=unigrams)
  110. true_classes_torch = torch.tensor(true_classes)