| @@ -0,0 +1,170 @@ | |||||
| import tensorflow as tf | |||||
| import numpy as np | |||||
| from collections import defaultdict | |||||
| import torch | |||||
| import torch.utils.data | |||||
| from typing import List, \ | |||||
| Union | |||||
| import icosagon.sampling | |||||
| import scipy.stats | |||||
| def test_unigram_01(): | |||||
| range_max = 7 | |||||
| distortion = 0.75 | |||||
| batch_size = 500 | |||||
| unigrams = [ 1, 3, 2, 1, 2, 1, 3] | |||||
| num_true = 1 | |||||
| true_classes = np.zeros((batch_size, num_true), dtype=np.int64) | |||||
| for i in range(batch_size): | |||||
| true_classes[i, 0] = i % range_max | |||||
| true_classes = tf.convert_to_tensor(true_classes) | |||||
| neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler( | |||||
| true_classes=true_classes, | |||||
| num_true=num_true, | |||||
| num_sampled=batch_size, | |||||
| unique=False, | |||||
| range_max=range_max, | |||||
| distortion=distortion, | |||||
| unigrams=unigrams) | |||||
| assert neg_samples.shape == (batch_size,) | |||||
| for i in range(batch_size): | |||||
| assert neg_samples[i] != true_classes[i, 0] | |||||
| counts = defaultdict(int) | |||||
| with tf.Session() as sess: | |||||
| neg_samples = neg_samples.eval() | |||||
| for x in neg_samples: | |||||
| counts[x] += 1 | |||||
| print('counts:', counts) | |||||
| assert counts[0] < counts[1] and \ | |||||
| counts[0] < counts[2] and \ | |||||
| counts[0] < counts[4] and \ | |||||
| counts[0] < counts[6] | |||||
| assert counts[2] < counts[1] and \ | |||||
| counts[0] < counts[6] | |||||
| assert counts[3] < counts[1] and \ | |||||
| counts[3] < counts[2] and \ | |||||
| counts[3] < counts[4] and \ | |||||
| counts[3] < counts[6] | |||||
| assert counts[4] < counts[1] and \ | |||||
| counts[4] < counts[6] | |||||
| assert counts[5] < counts[1] and \ | |||||
| counts[5] < counts[2] and \ | |||||
| counts[5] < counts[4] and \ | |||||
| counts[5] < counts[6] | |||||
| def test_unigram_02(): | |||||
| range_max = 7 | |||||
| distortion = 0.75 | |||||
| batch_size = 500 | |||||
| unigrams = [ 1, 3, 2, 1, 2, 1, 3] | |||||
| num_true = 1 | |||||
| true_classes = np.zeros((batch_size, num_true), dtype=np.int64) | |||||
| for i in range(batch_size): | |||||
| true_classes[i, 0] = i % range_max | |||||
| true_classes = torch.tensor(true_classes) | |||||
| neg_samples = icosagon.sampling.fixed_unigram_candidate_sampler( | |||||
| true_classes=true_classes, | |||||
| unigrams=unigrams, | |||||
| distortion=distortion) | |||||
| assert neg_samples.shape == (batch_size,) | |||||
| for i in range(batch_size): | |||||
| assert neg_samples[i] != true_classes[i, 0] | |||||
| counts = defaultdict(int) | |||||
| for x in neg_samples: | |||||
| counts[x.item()] += 1 | |||||
| print('counts:', counts) | |||||
| assert counts[0] < counts[1] and \ | |||||
| counts[0] < counts[2] and \ | |||||
| counts[0] < counts[4] and \ | |||||
| counts[0] < counts[6] | |||||
| assert counts[2] < counts[1] and \ | |||||
| counts[0] < counts[6] | |||||
| assert counts[3] < counts[1] and \ | |||||
| counts[3] < counts[2] and \ | |||||
| counts[3] < counts[4] and \ | |||||
| counts[3] < counts[6] | |||||
| assert counts[4] < counts[1] and \ | |||||
| counts[4] < counts[6] | |||||
| assert counts[5] < counts[1] and \ | |||||
| counts[5] < counts[2] and \ | |||||
| counts[5] < counts[4] and \ | |||||
| counts[5] < counts[6] | |||||
| def test_unigram_03(): | |||||
| range_max = 7 | |||||
| distortion = 0.75 | |||||
| batch_size = 25 | |||||
| unigrams = [ 1, 3, 2, 1, 2, 1, 3] | |||||
| num_true = 1 | |||||
| true_classes = np.zeros((batch_size, num_true), dtype=np.int64) | |||||
| for i in range(batch_size): | |||||
| true_classes[i, 0] = i % range_max | |||||
| true_classes_tf = tf.convert_to_tensor(true_classes) | |||||
| true_classes_torch = torch.tensor(true_classes) | |||||
| counts_tf = defaultdict(list) | |||||
| counts_torch = defaultdict(list) | |||||
| for i in range(100): | |||||
| neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler( | |||||
| true_classes=true_classes_tf, | |||||
| num_true=num_true, | |||||
| num_sampled=batch_size, | |||||
| unique=False, | |||||
| range_max=range_max, | |||||
| distortion=distortion, | |||||
| unigrams=unigrams) | |||||
| counts = defaultdict(int) | |||||
| with tf.Session() as sess: | |||||
| neg_samples = neg_samples.eval() | |||||
| for x in neg_samples: | |||||
| counts[x.item()] += 1 | |||||
| for k, v in counts.items(): | |||||
| counts_tf[k].append(v) | |||||
| neg_samples = icosagon.sampling.fixed_unigram_candidate_sampler( | |||||
| true_classes=true_classes, | |||||
| distortion=distortion, | |||||
| unigrams=unigrams) | |||||
| counts = defaultdict(int) | |||||
| for x in neg_samples: | |||||
| counts[x.item()] += 1 | |||||
| for k, v in counts.items(): | |||||
| counts_torch[k].append(v) | |||||
| for i in range(range_max): | |||||
| print('counts_tf[%d]:' % i, counts_tf[i]) | |||||
| print('counts_torch[%d]:' % i, counts_torch[i]) | |||||
| for i in range(range_max): | |||||
| statistic, pvalue = scipy.stats.ttest_ind(counts_tf[i], counts_torch[i]) | |||||
| assert pvalue * range_max > .05 | |||||