From 7ab97e2fb65330c3cd9282d59e47f727ed37e4f2 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 2 Jun 2020 12:08:34 +0200 Subject: [PATCH] Add test_unigram_03(). --- src/decagon_pytorch/sampling.py | 4 +- .../{test_unigram.py => test_sampling.py} | 57 ++++++++++++++----- 2 files changed, 45 insertions(+), 16 deletions(-) rename tests/decagon_pytorch/{test_unigram.py => test_sampling.py} (68%) diff --git a/src/decagon_pytorch/sampling.py b/src/decagon_pytorch/sampling.py index a029626..c2357a3 100644 --- a/src/decagon_pytorch/sampling.py +++ b/src/decagon_pytorch/sampling.py @@ -7,15 +7,13 @@ from typing import List, \ def fixed_unigram_candidate_sampler( true_classes: Union[np.array, torch.Tensor], - num_true: int, num_samples: int, - range_max: int, unigrams: List[Union[int, float]], distortion: float = 1.): if isinstance(true_classes, torch.Tensor): true_classes = true_classes.detach().cpu().numpy() - if true_classes.shape != (num_samples, num_true): + if true_classes.shape[0] != num_samples: raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)') unigrams = np.array(unigrams) if distortion != 1.: diff --git a/tests/decagon_pytorch/test_unigram.py b/tests/decagon_pytorch/test_sampling.py similarity index 68% rename from tests/decagon_pytorch/test_unigram.py rename to tests/decagon_pytorch/test_sampling.py index c08c1bd..028d5b5 100644 --- a/tests/decagon_pytorch/test_unigram.py +++ b/tests/decagon_pytorch/test_sampling.py @@ -6,6 +6,7 @@ import torch.utils.data from typing import List, \ Union import decagon_pytorch.sampling +import scipy.stats def test_unigram_01(): @@ -78,9 +79,7 @@ def test_unigram_02(): neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler( true_classes=true_classes, - num_true=num_true, num_samples=batch_size, - range_max=range_max, distortion=distortion, unigrams=unigrams) @@ -120,22 +119,54 @@ def test_unigram_02(): def test_unigram_03(): range_max = 7 distortion = 0.75 - batch_size = 500 + batch_size = 25 unigrams = [ 1, 3, 2, 1, 2, 1, 3] num_true = 1 true_classes = np.zeros((batch_size, num_true), dtype=np.int64) for i in range(batch_size): true_classes[i, 0] = i % range_max - true_classes_tf = tf.convert_to_tensor(true_classes) - - neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler( - true_classes=true_classes_tf, - num_true=num_true, - num_sampled=batch_size, - unique=False, - range_max=range_max, - distortion=distortion, - unigrams=unigrams) + true_classes_tf = tf.convert_to_tensor(true_classes) true_classes_torch = torch.tensor(true_classes) + + counts_tf = defaultdict(list) + counts_torch = defaultdict(list) + + for i in range(100): + neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler( + true_classes=true_classes_tf, + num_true=num_true, + num_sampled=batch_size, + unique=False, + range_max=range_max, + distortion=distortion, + unigrams=unigrams) + + counts = defaultdict(int) + with tf.Session() as sess: + neg_samples = neg_samples.eval() + for x in neg_samples: + counts[x] += 1 + for k, v in counts.items(): + counts_tf[k].append(v) + + neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler( + true_classes=true_classes, + num_samples=batch_size, + distortion=distortion, + unigrams=unigrams) + + counts = defaultdict(int) + for x in neg_samples: + counts[x] += 1 + for k, v in counts.items(): + counts_torch[k].append(v) + + for i in range(range_max): + print('counts_tf[%d]:' % i, counts_tf[i]) + print('counts_torch[%d]:' % i, counts_torch[i]) + + for i in range(range_max): + statistic, pvalue = scipy.stats.ttest_ind(counts_tf[i], counts_torch[i]) + assert pvalue * range_max > .05