|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- import tensorflow as tf
- import numpy as np
- from collections import defaultdict
- import torch
- import torch.utils.data
- from typing import List, \
- Union
- import icosagon.sampling
- import scipy.stats
-
-
def test_unigram_01():
    """Check TensorFlow's fixed unigram sampler respects the unigram weights.

    Draws ``batch_size`` candidates with
    ``tf.nn.fixed_unigram_candidate_sampler`` (TF1 graph mode) and asserts
    that every class with a strictly smaller unigram weight is drawn
    strictly less often than every class with a strictly larger weight.
    """
    range_max = 7
    distortion = 0.75
    # 2500 draws rather than 500: with 500, the expected counts of the
    # weight-2 and weight-3 classes (~77 vs ~104 after distortion) are only
    # about two standard deviations apart, which made the ordering asserts
    # below flaky.
    batch_size = 2500
    unigrams = [1, 3, 2, 1, 2, 1, 3]
    num_true = 1

    # Row i's single true class is i % range_max, so every class appears.
    true_classes = (np.arange(batch_size, dtype=np.int64) % range_max).reshape(
        batch_size, num_true)
    true_classes = tf.convert_to_tensor(true_classes)

    neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
        true_classes=true_classes,
        num_true=num_true,
        num_sampled=batch_size,
        unique=False,
        range_max=range_max,
        distortion=distortion,
        unigrams=unigrams)

    assert neg_samples.shape == (batch_size,)

    # NOTE: there is intentionally no per-row "sample != true class" check.
    # The TF sampler returns a single candidate set of shape [num_sampled]
    # shared by the whole batch and does not exclude true classes. The
    # former check compared graph Tensor objects, which in TF1 graph mode is
    # an identity comparison, so it always passed without testing anything.
    with tf.Session() as sess:
        samples = sess.run(neg_samples)

    counts = np.bincount(samples, minlength=range_max)
    print('counts:', counts)

    # Every class with a strictly smaller unigram weight must be drawn
    # strictly less often than every class with a larger weight.
    for smaller in range(range_max):
        for larger in range(range_max):
            if unigrams[smaller] < unigrams[larger]:
                assert counts[smaller] < counts[larger], (smaller, larger, counts)
-
-
def test_unigram_02():
    """Check icosagon's fixed unigram sampler respects weights and true classes.

    Draws one negative sample per row with
    ``icosagon.sampling.fixed_unigram_candidate_sampler`` and asserts that:

    * the result has one sample per row,
    * no row's sample equals that row's true class, and
    * every class with a strictly smaller unigram weight is drawn strictly
      less often than every class with a strictly larger weight.
    """
    range_max = 7
    distortion = 0.75
    # 2500 rows rather than 500: with 500, the expected counts of the
    # weight-2 and weight-3 classes are only about two standard deviations
    # apart, which made the ordering asserts below flaky.
    batch_size = 2500
    unigrams = [1, 3, 2, 1, 2, 1, 3]
    num_true = 1

    # Row i's single true class is i % range_max, so every class appears.
    true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
    true_classes[:, 0] = np.arange(batch_size) % range_max
    true_classes = torch.tensor(true_classes)

    neg_samples = icosagon.sampling.fixed_unigram_candidate_sampler(
        true_classes=true_classes,
        unigrams=unigrams,
        distortion=distortion)

    assert neg_samples.shape == (batch_size,)

    for row, (sample, true_class) in enumerate(zip(neg_samples, true_classes[:, 0])):
        assert sample != true_class, row

    counts = defaultdict(int)
    for x in neg_samples:
        counts[x.item()] += 1

    print('counts:', counts)

    # Every class with a strictly smaller unigram weight must be drawn
    # strictly less often than every class with a larger weight.
    for smaller in range(range_max):
        for larger in range(range_max):
            if unigrams[smaller] < unigrams[larger]:
                assert counts[smaller] < counts[larger], (smaller, larger, dict(counts))
-
-
def test_unigram_03():
    """Compare the class distributions produced by the TF and icosagon samplers.

    Runs each sampler ``num_rounds`` times on the same true classes,
    accumulates per-class counts, and asserts that the two empirical
    distributions over class indices are close in 1-D Wasserstein distance.
    """
    range_max = 7
    distortion = 0.75
    batch_size = 2500
    unigrams = [1, 3, 2, 1, 2, 1, 3]
    num_true = 1
    num_rounds = 10

    # Row i's single true class is i % range_max, so every class appears.
    true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
    true_classes[:, 0] = np.arange(batch_size) % range_max

    true_classes_tf = tf.convert_to_tensor(true_classes)
    true_classes_torch = torch.tensor(true_classes)

    counts_tf = np.zeros(range_max)
    counts_torch = np.zeros(range_max)

    # Build the TF sampling op once and run it repeatedly in one session:
    # each sess.run draws a fresh batch, and the graph no longer grows by
    # one sampler op per round.
    neg_samples_tf, _, _ = tf.nn.fixed_unigram_candidate_sampler(
        true_classes=true_classes_tf,
        num_true=num_true,
        num_sampled=batch_size,
        unique=False,
        range_max=range_max,
        distortion=distortion,
        unigrams=unigrams)

    with tf.Session() as sess:
        for _ in range(num_rounds):
            samples = sess.run(neg_samples_tf)
            counts_tf += np.bincount(samples, minlength=range_max)

    for _ in range(num_rounds):
        # Pass the torch tensor, as test_unigram_02 does; the original passed
        # the raw numpy array and left true_classes_torch unused.
        neg_samples = icosagon.sampling.fixed_unigram_candidate_sampler(
            true_classes=true_classes_torch,
            distortion=distortion,
            unigrams=unigrams)
        for x in neg_samples.tolist():
            counts_torch[x] += 1

    print('counts_tf:', counts_tf)
    print('counts_torch:', counts_torch)

    # Compare distributions over class indices: the classes are the values
    # and the counts are their weights (scipy normalizes the weights).
    # Passing the counts themselves as values, as before, ignored which class
    # each count belonged to; even a uniform sampler passed the old
    # `< 2000` bound.
    classes = np.arange(range_max)
    distance = scipy.stats.wasserstein_distance(
        classes, classes, u_weights=counts_tf, v_weights=counts_torch)
    # Distance is in class-index units (at most range_max - 1). A uniform
    # sampler scores ~0.2 against these weights. Sampling noise at 25000
    # draws per sampler is a few hundredths, so 0.1 separates the two.
    assert distance < 0.1, distance
|