From f2121bef539c0ac9c23e72e908bd417064f966a8 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 2 Jun 2020 11:50:51 +0200 Subject: [PATCH] Add fixed_unigram_candidate_sampler(). --- src/decagon_pytorch/sampling.py | 38 +++++++ tests/decagon_pytorch/test_unigram.py | 141 ++++++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 src/decagon_pytorch/sampling.py create mode 100644 tests/decagon_pytorch/test_unigram.py diff --git a/src/decagon_pytorch/sampling.py b/src/decagon_pytorch/sampling.py new file mode 100644 index 0000000..a029626 --- /dev/null +++ b/src/decagon_pytorch/sampling.py @@ -0,0 +1,38 @@ +import numpy as np +import torch +import torch.utils.data +from typing import List, \ + Union + + +def fixed_unigram_candidate_sampler( + true_classes: Union[np.array, torch.Tensor], + num_true: int, + num_samples: int, + range_max: int, + unigrams: List[Union[int, float]], + distortion: float = 1.): + + if isinstance(true_classes, torch.Tensor): + true_classes = true_classes.detach().cpu().numpy() + if true_classes.shape != (num_samples, num_true): + raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)') + unigrams = np.array(unigrams) + if distortion != 1.: + unigrams = unigrams.astype(np.float64) ** distortion + # print('unigrams:', unigrams) + indices = np.arange(num_samples) + result = np.zeros(num_samples, dtype=np.int64) + while len(indices) > 0: + # print('len(indices):', len(indices)) + sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices)) + candidates = np.array(list(sampler)) + candidates = np.reshape(candidates, (len(indices), 1)) + # print('candidates:', candidates) + # print('true_classes:', true_classes[indices, :]) + result[indices] = candidates.T + mask = (candidates == true_classes[indices, :]) + mask = mask.sum(1).astype(np.bool) + # print('mask:', mask) + indices = indices[mask] + return result diff --git a/tests/decagon_pytorch/test_unigram.py b/tests/decagon_pytorch/test_unigram.py new file mode 100644 index 0000000..c08c1bd --- /dev/null +++ b/tests/decagon_pytorch/test_unigram.py @@ -0,0 +1,141 @@ +import tensorflow as tf +import numpy as np +from collections import defaultdict +import torch +import torch.utils.data +from typing import List, \ + Union +import decagon_pytorch.sampling + + +def test_unigram_01(): + range_max = 7 + distortion = 0.75 + batch_size = 500 + unigrams = [ 1, 3, 2, 1, 2, 1, 3] + num_true = 1 + + true_classes = np.zeros((batch_size, num_true), dtype=np.int64) + for i in range(batch_size): + true_classes[i, 0] = i % range_max + true_classes = tf.convert_to_tensor(true_classes) + + neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler( + true_classes=true_classes, + num_true=num_true, + num_sampled=batch_size, + unique=False, + range_max=range_max, + distortion=distortion, + unigrams=unigrams) + + assert neg_samples.shape == (batch_size,) + + for i in range(batch_size): + assert neg_samples[i] != true_classes[i, 0] + + counts = defaultdict(int) + with tf.Session() as sess: + neg_samples = neg_samples.eval() + for x in neg_samples: + counts[x] += 1 + + print('counts:', counts) + + assert counts[0] < counts[1] and \ + counts[0] < counts[2] and \ + counts[0] < counts[4] and \ + counts[0] < counts[6] + + assert counts[2] < counts[1] and \ + counts[0] < counts[6] + + assert counts[3] < counts[1] and \ + counts[3] < counts[2] and \ + counts[3] < counts[4] and \ + counts[3] < counts[6] + + assert counts[4] < counts[1] and \ + counts[4] < counts[6] + + assert counts[5] < counts[1] and \ + counts[5] < counts[2] and \ + counts[5] < counts[4] and \ + counts[5] < counts[6] + + +def test_unigram_02(): + range_max = 7 + distortion = 0.75 + batch_size = 500 + unigrams = [ 1, 3, 2, 1, 2, 1, 3] + num_true = 1 + + true_classes = np.zeros((batch_size, num_true), dtype=np.int64) + for i in range(batch_size): + true_classes[i, 0] = i % range_max + true_classes = torch.tensor(true_classes) + + neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler( + true_classes=true_classes, + num_true=num_true, + num_samples=batch_size, + range_max=range_max, + distortion=distortion, + unigrams=unigrams) + + assert neg_samples.shape == (batch_size,) + + for i in range(batch_size): + assert neg_samples[i] != true_classes[i, 0] + + counts = defaultdict(int) + for x in neg_samples: + counts[x] += 1 + + print('counts:', counts) + + assert counts[0] < counts[1] and \ + counts[0] < counts[2] and \ + counts[0] < counts[4] and \ + counts[0] < counts[6] + + assert counts[2] < counts[1] and \ + counts[0] < counts[6] + + assert counts[3] < counts[1] and \ + counts[3] < counts[2] and \ + counts[3] < counts[4] and \ + counts[3] < counts[6] + + assert counts[4] < counts[1] and \ + counts[4] < counts[6] + + assert counts[5] < counts[1] and \ + counts[5] < counts[2] and \ + counts[5] < counts[4] and \ + counts[5] < counts[6] + + +def test_unigram_03(): + range_max = 7 + distortion = 0.75 + batch_size = 500 + unigrams = [ 1, 3, 2, 1, 2, 1, 3] + num_true = 1 + + true_classes = np.zeros((batch_size, num_true), dtype=np.int64) + for i in range(batch_size): + true_classes[i, 0] = i % range_max + true_classes_tf = tf.convert_to_tensor(true_classes) + + neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler( + true_classes=true_classes_tf, + num_true=num_true, + num_sampled=batch_size, + unique=False, + range_max=range_max, + distortion=distortion, + unigrams=unigrams) + + true_classes_torch = torch.tensor(true_classes)