|  |  | @@ -0,0 +1,141 @@ | 
		
	
		
			
			|  |  |  | import tensorflow as tf | 
		
	
		
			
			|  |  |  | import numpy as np | 
		
	
		
			
			|  |  |  | from collections import defaultdict | 
		
	
		
			
			|  |  |  | import torch | 
		
	
		
			
			|  |  |  | import torch.utils.data | 
		
	
		
			
			|  |  |  | from typing import List, \ | 
		
	
		
			
			|  |  |  | Union | 
		
	
		
			
			|  |  |  | import decagon_pytorch.sampling | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_unigram_01(): | 
		
	
		
			
			|  |  |  | range_max = 7 | 
		
	
		
			
			|  |  |  | distortion = 0.75 | 
		
	
		
			
			|  |  |  | batch_size = 500 | 
		
	
		
			
			|  |  |  | unigrams = [ 1, 3, 2, 1, 2, 1, 3] | 
		
	
		
			
			|  |  |  | num_true = 1 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | true_classes = np.zeros((batch_size, num_true), dtype=np.int64) | 
		
	
		
			
			|  |  |  | for i in range(batch_size): | 
		
	
		
			
			|  |  |  | true_classes[i, 0] = i % range_max | 
		
	
		
			
			|  |  |  | true_classes = tf.convert_to_tensor(true_classes) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler( | 
		
	
		
			
			|  |  |  | true_classes=true_classes, | 
		
	
		
			
			|  |  |  | num_true=num_true, | 
		
	
		
			
			|  |  |  | num_sampled=batch_size, | 
		
	
		
			
			|  |  |  | unique=False, | 
		
	
		
			
			|  |  |  | range_max=range_max, | 
		
	
		
			
			|  |  |  | distortion=distortion, | 
		
	
		
			
			|  |  |  | unigrams=unigrams) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert neg_samples.shape == (batch_size,) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for i in range(batch_size): | 
		
	
		
			
			|  |  |  | assert neg_samples[i] != true_classes[i, 0] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | counts = defaultdict(int) | 
		
	
		
			
			|  |  |  | with tf.Session() as sess: | 
		
	
		
			
			|  |  |  | neg_samples = neg_samples.eval() | 
		
	
		
			
			|  |  |  | for x in neg_samples: | 
		
	
		
			
			|  |  |  | counts[x] += 1 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | print('counts:', counts) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[0] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[0] < counts[2] and \ | 
		
	
		
			
			|  |  |  | counts[0] < counts[4] and \ | 
		
	
		
			
			|  |  |  | counts[0] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[2] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[0] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[3] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[3] < counts[2] and \ | 
		
	
		
			
			|  |  |  | counts[3] < counts[4] and \ | 
		
	
		
			
			|  |  |  | counts[3] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[4] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[4] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[5] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[5] < counts[2] and \ | 
		
	
		
			
			|  |  |  | counts[5] < counts[4] and \ | 
		
	
		
			
			|  |  |  | counts[5] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_unigram_02(): | 
		
	
		
			
			|  |  |  | range_max = 7 | 
		
	
		
			
			|  |  |  | distortion = 0.75 | 
		
	
		
			
			|  |  |  | batch_size = 500 | 
		
	
		
			
			|  |  |  | unigrams = [ 1, 3, 2, 1, 2, 1, 3] | 
		
	
		
			
			|  |  |  | num_true = 1 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | true_classes = np.zeros((batch_size, num_true), dtype=np.int64) | 
		
	
		
			
			|  |  |  | for i in range(batch_size): | 
		
	
		
			
			|  |  |  | true_classes[i, 0] = i % range_max | 
		
	
		
			
			|  |  |  | true_classes = torch.tensor(true_classes) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler( | 
		
	
		
			
			|  |  |  | true_classes=true_classes, | 
		
	
		
			
			|  |  |  | num_true=num_true, | 
		
	
		
			
			|  |  |  | num_samples=batch_size, | 
		
	
		
			
			|  |  |  | range_max=range_max, | 
		
	
		
			
			|  |  |  | distortion=distortion, | 
		
	
		
			
			|  |  |  | unigrams=unigrams) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert neg_samples.shape == (batch_size,) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for i in range(batch_size): | 
		
	
		
			
			|  |  |  | assert neg_samples[i] != true_classes[i, 0] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | counts = defaultdict(int) | 
		
	
		
			
			|  |  |  | for x in neg_samples: | 
		
	
		
			
			|  |  |  | counts[x] += 1 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | print('counts:', counts) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[0] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[0] < counts[2] and \ | 
		
	
		
			
			|  |  |  | counts[0] < counts[4] and \ | 
		
	
		
			
			|  |  |  | counts[0] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[2] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[0] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[3] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[3] < counts[2] and \ | 
		
	
		
			
			|  |  |  | counts[3] < counts[4] and \ | 
		
	
		
			
			|  |  |  | counts[3] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[4] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[4] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert counts[5] < counts[1] and \ | 
		
	
		
			
			|  |  |  | counts[5] < counts[2] and \ | 
		
	
		
			
			|  |  |  | counts[5] < counts[4] and \ | 
		
	
		
			
			|  |  |  | counts[5] < counts[6] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_unigram_03(): | 
		
	
		
			
			|  |  |  | range_max = 7 | 
		
	
		
			
			|  |  |  | distortion = 0.75 | 
		
	
		
			
			|  |  |  | batch_size = 500 | 
		
	
		
			
			|  |  |  | unigrams = [ 1, 3, 2, 1, 2, 1, 3] | 
		
	
		
			
			|  |  |  | num_true = 1 | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | true_classes = np.zeros((batch_size, num_true), dtype=np.int64) | 
		
	
		
			
			|  |  |  | for i in range(batch_size): | 
		
	
		
			
			|  |  |  | true_classes[i, 0] = i % range_max | 
		
	
		
			
			|  |  |  | true_classes_tf = tf.convert_to_tensor(true_classes) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler( | 
		
	
		
			
			|  |  |  | true_classes=true_classes_tf, | 
		
	
		
			
			|  |  |  | num_true=num_true, | 
		
	
		
			
			|  |  |  | num_sampled=batch_size, | 
		
	
		
			
			|  |  |  | unique=False, | 
		
	
		
			
			|  |  |  | range_max=range_max, | 
		
	
		
			
			|  |  |  | distortion=distortion, | 
		
	
		
			
			|  |  |  | unigrams=unigrams) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | true_classes_torch = torch.tensor(true_classes) |