IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

173 行
4.8KB

from collections import Counter, defaultdict
from typing import List, Union

import numpy as np
import scipy.stats
import tensorflow as tf
import torch
import torch.utils.data

import decagon_pytorch.sampling
  10. def test_unigram_01():
  11. range_max = 7
  12. distortion = 0.75
  13. batch_size = 500
  14. unigrams = [ 1, 3, 2, 1, 2, 1, 3]
  15. num_true = 1
  16. true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
  17. for i in range(batch_size):
  18. true_classes[i, 0] = i % range_max
  19. true_classes = tf.convert_to_tensor(true_classes)
  20. neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
  21. true_classes=true_classes,
  22. num_true=num_true,
  23. num_sampled=batch_size,
  24. unique=False,
  25. range_max=range_max,
  26. distortion=distortion,
  27. unigrams=unigrams)
  28. assert neg_samples.shape == (batch_size,)
  29. for i in range(batch_size):
  30. assert neg_samples[i] != true_classes[i, 0]
  31. counts = defaultdict(int)
  32. with tf.Session() as sess:
  33. neg_samples = neg_samples.eval()
  34. for x in neg_samples:
  35. counts[x] += 1
  36. print('counts:', counts)
  37. assert counts[0] < counts[1] and \
  38. counts[0] < counts[2] and \
  39. counts[0] < counts[4] and \
  40. counts[0] < counts[6]
  41. assert counts[2] < counts[1] and \
  42. counts[0] < counts[6]
  43. assert counts[3] < counts[1] and \
  44. counts[3] < counts[2] and \
  45. counts[3] < counts[4] and \
  46. counts[3] < counts[6]
  47. assert counts[4] < counts[1] and \
  48. counts[4] < counts[6]
  49. assert counts[5] < counts[1] and \
  50. counts[5] < counts[2] and \
  51. counts[5] < counts[4] and \
  52. counts[5] < counts[6]
  53. def test_unigram_02():
  54. range_max = 7
  55. distortion = 0.75
  56. batch_size = 500
  57. unigrams = [ 1, 3, 2, 1, 2, 1, 3]
  58. num_true = 1
  59. true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
  60. for i in range(batch_size):
  61. true_classes[i, 0] = i % range_max
  62. true_classes = torch.tensor(true_classes)
  63. neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler(
  64. true_classes=true_classes,
  65. num_samples=batch_size,
  66. distortion=distortion,
  67. unigrams=unigrams)
  68. assert neg_samples.shape == (batch_size,)
  69. for i in range(batch_size):
  70. assert neg_samples[i] != true_classes[i, 0]
  71. counts = defaultdict(int)
  72. for x in neg_samples:
  73. counts[x] += 1
  74. print('counts:', counts)
  75. assert counts[0] < counts[1] and \
  76. counts[0] < counts[2] and \
  77. counts[0] < counts[4] and \
  78. counts[0] < counts[6]
  79. assert counts[2] < counts[1] and \
  80. counts[0] < counts[6]
  81. assert counts[3] < counts[1] and \
  82. counts[3] < counts[2] and \
  83. counts[3] < counts[4] and \
  84. counts[3] < counts[6]
  85. assert counts[4] < counts[1] and \
  86. counts[4] < counts[6]
  87. assert counts[5] < counts[1] and \
  88. counts[5] < counts[2] and \
  89. counts[5] < counts[4] and \
  90. counts[5] < counts[6]
  91. def test_unigram_03():
  92. range_max = 7
  93. distortion = 0.75
  94. batch_size = 25
  95. unigrams = [ 1, 3, 2, 1, 2, 1, 3]
  96. num_true = 1
  97. true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
  98. for i in range(batch_size):
  99. true_classes[i, 0] = i % range_max
  100. true_classes_tf = tf.convert_to_tensor(true_classes)
  101. true_classes_torch = torch.tensor(true_classes)
  102. counts_tf = defaultdict(list)
  103. counts_torch = defaultdict(list)
  104. for i in range(100):
  105. neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
  106. true_classes=true_classes_tf,
  107. num_true=num_true,
  108. num_sampled=batch_size,
  109. unique=False,
  110. range_max=range_max,
  111. distortion=distortion,
  112. unigrams=unigrams)
  113. counts = defaultdict(int)
  114. with tf.Session() as sess:
  115. neg_samples = neg_samples.eval()
  116. for x in neg_samples:
  117. counts[x] += 1
  118. for k, v in counts.items():
  119. counts_tf[k].append(v)
  120. neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler(
  121. true_classes=true_classes,
  122. num_samples=batch_size,
  123. distortion=distortion,
  124. unigrams=unigrams)
  125. counts = defaultdict(int)
  126. for x in neg_samples:
  127. counts[x] += 1
  128. for k, v in counts.items():
  129. counts_torch[k].append(v)
  130. for i in range(range_max):
  131. print('counts_tf[%d]:' % i, counts_tf[i])
  132. print('counts_torch[%d]:' % i, counts_torch[i])
  133. for i in range(range_max):
  134. statistic, pvalue = scipy.stats.ttest_ind(counts_tf[i], counts_torch[i])
  135. assert pvalue * range_max > .05