IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Преглед на файлове

Add fixed_unigram_candidate_sampler().

master
Stanislaw Adaszewski преди 4 години
родител
ревизия
f2121bef53
променени са 2 файла, в които са добавени 179 реда и са изтрити 0 реда
  1. +38
    -0
      src/decagon_pytorch/sampling.py
  2. +141
    -0
      tests/decagon_pytorch/test_unigram.py

+ 38
- 0
src/decagon_pytorch/sampling.py Целия файл

@@ -0,0 +1,38 @@
import numpy as np
import torch
import torch.utils.data
from typing import List, \
Union
def fixed_unigram_candidate_sampler(
true_classes: Union[np.array, torch.Tensor],
num_true: int,
num_samples: int,
range_max: int,
unigrams: List[Union[int, float]],
distortion: float = 1.):
if isinstance(true_classes, torch.Tensor):
true_classes = true_classes.detach().cpu().numpy()
if true_classes.shape != (num_samples, num_true):
raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
unigrams = np.array(unigrams)
if distortion != 1.:
unigrams = unigrams.astype(np.float64) ** distortion
# print('unigrams:', unigrams)
indices = np.arange(num_samples)
result = np.zeros(num_samples, dtype=np.int64)
while len(indices) > 0:
# print('len(indices):', len(indices))
sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
candidates = np.array(list(sampler))
candidates = np.reshape(candidates, (len(indices), 1))
# print('candidates:', candidates)
# print('true_classes:', true_classes[indices, :])
result[indices] = candidates.T
mask = (candidates == true_classes[indices, :])
mask = mask.sum(1).astype(np.bool)
# print('mask:', mask)
indices = indices[mask]
return result

+ 141
- 0
tests/decagon_pytorch/test_unigram.py Целия файл

@@ -0,0 +1,141 @@
import tensorflow as tf
import numpy as np
from collections import defaultdict
import torch
import torch.utils.data
from typing import List, \
Union
import decagon_pytorch.sampling
def test_unigram_01():
range_max = 7
distortion = 0.75
batch_size = 500
unigrams = [ 1, 3, 2, 1, 2, 1, 3]
num_true = 1
true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
for i in range(batch_size):
true_classes[i, 0] = i % range_max
true_classes = tf.convert_to_tensor(true_classes)
neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
true_classes=true_classes,
num_true=num_true,
num_sampled=batch_size,
unique=False,
range_max=range_max,
distortion=distortion,
unigrams=unigrams)
assert neg_samples.shape == (batch_size,)
for i in range(batch_size):
assert neg_samples[i] != true_classes[i, 0]
counts = defaultdict(int)
with tf.Session() as sess:
neg_samples = neg_samples.eval()
for x in neg_samples:
counts[x] += 1
print('counts:', counts)
assert counts[0] < counts[1] and \
counts[0] < counts[2] and \
counts[0] < counts[4] and \
counts[0] < counts[6]
assert counts[2] < counts[1] and \
counts[0] < counts[6]
assert counts[3] < counts[1] and \
counts[3] < counts[2] and \
counts[3] < counts[4] and \
counts[3] < counts[6]
assert counts[4] < counts[1] and \
counts[4] < counts[6]
assert counts[5] < counts[1] and \
counts[5] < counts[2] and \
counts[5] < counts[4] and \
counts[5] < counts[6]
def test_unigram_02():
range_max = 7
distortion = 0.75
batch_size = 500
unigrams = [ 1, 3, 2, 1, 2, 1, 3]
num_true = 1
true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
for i in range(batch_size):
true_classes[i, 0] = i % range_max
true_classes = torch.tensor(true_classes)
neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler(
true_classes=true_classes,
num_true=num_true,
num_samples=batch_size,
range_max=range_max,
distortion=distortion,
unigrams=unigrams)
assert neg_samples.shape == (batch_size,)
for i in range(batch_size):
assert neg_samples[i] != true_classes[i, 0]
counts = defaultdict(int)
for x in neg_samples:
counts[x] += 1
print('counts:', counts)
assert counts[0] < counts[1] and \
counts[0] < counts[2] and \
counts[0] < counts[4] and \
counts[0] < counts[6]
assert counts[2] < counts[1] and \
counts[0] < counts[6]
assert counts[3] < counts[1] and \
counts[3] < counts[2] and \
counts[3] < counts[4] and \
counts[3] < counts[6]
assert counts[4] < counts[1] and \
counts[4] < counts[6]
assert counts[5] < counts[1] and \
counts[5] < counts[2] and \
counts[5] < counts[4] and \
counts[5] < counts[6]
def test_unigram_03():
range_max = 7
distortion = 0.75
batch_size = 500
unigrams = [ 1, 3, 2, 1, 2, 1, 3]
num_true = 1
true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
for i in range(batch_size):
true_classes[i, 0] = i % range_max
true_classes_tf = tf.convert_to_tensor(true_classes)
neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
true_classes=true_classes_tf,
num_true=num_true,
num_sampled=batch_size,
unique=False,
range_max=range_max,
distortion=distortion,
unigrams=unigrams)
true_classes_torch = torch.tensor(true_classes)

Loading…
Отказ
Запис