IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
ソースを参照

Add sampling tests.

master
Stanislaw Adaszewski 4年前
コミット
702079c7e9
1個のファイルの変更170行の追加0行の削除
  1. +170
    -0
      tests/icosagon/test_sampling.py

+ 170
- 0
tests/icosagon/test_sampling.py ファイルの表示

@@ -0,0 +1,170 @@
import tensorflow as tf
import numpy as np
from collections import defaultdict
import torch
import torch.utils.data
from typing import List, \
Union
import icosagon.sampling
import scipy.stats
def test_unigram_01():
range_max = 7
distortion = 0.75
batch_size = 500
unigrams = [ 1, 3, 2, 1, 2, 1, 3]
num_true = 1
true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
for i in range(batch_size):
true_classes[i, 0] = i % range_max
true_classes = tf.convert_to_tensor(true_classes)
neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
true_classes=true_classes,
num_true=num_true,
num_sampled=batch_size,
unique=False,
range_max=range_max,
distortion=distortion,
unigrams=unigrams)
assert neg_samples.shape == (batch_size,)
for i in range(batch_size):
assert neg_samples[i] != true_classes[i, 0]
counts = defaultdict(int)
with tf.Session() as sess:
neg_samples = neg_samples.eval()
for x in neg_samples:
counts[x] += 1
print('counts:', counts)
assert counts[0] < counts[1] and \
counts[0] < counts[2] and \
counts[0] < counts[4] and \
counts[0] < counts[6]
assert counts[2] < counts[1] and \
counts[0] < counts[6]
assert counts[3] < counts[1] and \
counts[3] < counts[2] and \
counts[3] < counts[4] and \
counts[3] < counts[6]
assert counts[4] < counts[1] and \
counts[4] < counts[6]
assert counts[5] < counts[1] and \
counts[5] < counts[2] and \
counts[5] < counts[4] and \
counts[5] < counts[6]
def test_unigram_02():
range_max = 7
distortion = 0.75
batch_size = 500
unigrams = [ 1, 3, 2, 1, 2, 1, 3]
num_true = 1
true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
for i in range(batch_size):
true_classes[i, 0] = i % range_max
true_classes = torch.tensor(true_classes)
neg_samples = icosagon.sampling.fixed_unigram_candidate_sampler(
true_classes=true_classes,
unigrams=unigrams,
distortion=distortion)
assert neg_samples.shape == (batch_size,)
for i in range(batch_size):
assert neg_samples[i] != true_classes[i, 0]
counts = defaultdict(int)
for x in neg_samples:
counts[x.item()] += 1
print('counts:', counts)
assert counts[0] < counts[1] and \
counts[0] < counts[2] and \
counts[0] < counts[4] and \
counts[0] < counts[6]
assert counts[2] < counts[1] and \
counts[0] < counts[6]
assert counts[3] < counts[1] and \
counts[3] < counts[2] and \
counts[3] < counts[4] and \
counts[3] < counts[6]
assert counts[4] < counts[1] and \
counts[4] < counts[6]
assert counts[5] < counts[1] and \
counts[5] < counts[2] and \
counts[5] < counts[4] and \
counts[5] < counts[6]
def test_unigram_03():
range_max = 7
distortion = 0.75
batch_size = 25
unigrams = [ 1, 3, 2, 1, 2, 1, 3]
num_true = 1
true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
for i in range(batch_size):
true_classes[i, 0] = i % range_max
true_classes_tf = tf.convert_to_tensor(true_classes)
true_classes_torch = torch.tensor(true_classes)
counts_tf = defaultdict(list)
counts_torch = defaultdict(list)
for i in range(100):
neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
true_classes=true_classes_tf,
num_true=num_true,
num_sampled=batch_size,
unique=False,
range_max=range_max,
distortion=distortion,
unigrams=unigrams)
counts = defaultdict(int)
with tf.Session() as sess:
neg_samples = neg_samples.eval()
for x in neg_samples:
counts[x.item()] += 1
for k, v in counts.items():
counts_tf[k].append(v)
neg_samples = icosagon.sampling.fixed_unigram_candidate_sampler(
true_classes=true_classes,
distortion=distortion,
unigrams=unigrams)
counts = defaultdict(int)
for x in neg_samples:
counts[x.item()] += 1
for k, v in counts.items():
counts_torch[k].append(v)
for i in range(range_max):
print('counts_tf[%d]:' % i, counts_tf[i])
print('counts_torch[%d]:' % i, counts_torch[i])
for i in range(range_max):
statistic, pvalue = scipy.stats.ttest_ind(counts_tf[i], counts_torch[i])
assert pvalue * range_max > .05

読み込み中…
キャンセル
保存