|
|
@@ -14,6 +14,60 @@ from .data import Data, \ |
|
|
|
EdgeType
|
|
|
|
from .cumcount import cumcount
|
|
|
|
import time
|
|
|
|
import multiprocessing
|
|
|
|
import multiprocessing.pool
|
|
|
|
from itertools import product, \
|
|
|
|
repeat
|
|
|
|
from functools import reduce
|
|
|
|
|
|
|
|
|
|
|
|
def fixed_unigram_candidate_sampler_slow(
        true_classes: torch.Tensor,
        num_repeats: torch.Tensor,
        unigrams: torch.Tensor,
        distortion: float = 1.) -> torch.Tensor:
    """Sample candidate classes row by row, excluding each row's true classes.

    For every row ``i`` of ``true_classes``, ``num_repeats[i]`` distinct class
    indices are drawn without replacement, with weights taken from
    ``unigrams`` (raised to the power ``distortion`` when it differs from 1).
    The weights of the row's non-negative true classes are zeroed first, so
    those classes are never drawn for that row.

    Args:
        true_classes: 2D tensor of shape ``(num_samples, num_true)``.
            Negative entries are ignored (presumably padding - confirm
            against callers).
        num_repeats: 1D tensor giving the number of classes to draw for each
            row; rows with a count of 0 contribute nothing to the result.
        unigrams: Per-class sampling weights, indexed by class id.
        distortion: Exponent applied to ``unigrams``. When it is not 1 the
            weights are converted to ``float64`` before exponentiation.

    Returns:
        A 1D ``torch.long`` tensor with the sampled class ids of all rows,
        concatenated in row order. It is empty (but still ``torch.long``)
        when no samples were requested.

    Raises:
        AssertionError: If any of the three tensor arguments is not a
            ``torch.Tensor``.
        ValueError: If ``true_classes`` is not 2D, ``num_repeats`` is not 1D,
            or some row does not leave enough classes to draw from.
    """
    assert isinstance(true_classes, torch.Tensor)
    assert isinstance(num_repeats, torch.Tensor)
    assert isinstance(unigrams, torch.Tensor)
    distortion = float(distortion)

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    # Conservative availability check: every non-negative entry of a row is
    # subtracted from the number of positive-weight classes, even if the entry
    # is a duplicate or refers to a class whose weight is already zero.
    if torch.any((unigrams > 0).sum() -
                 (true_classes >= 0).sum(dim=1) <
                 num_repeats):
        raise ValueError('Not enough classes to choose from')

    if distortion != 1.:
        unigrams = unigrams.to(torch.float64)
        unigrams = unigrams ** distortion

    # Convert once to plain Python numbers instead of calling .item() and
    # comparing tensors for every row inside the worker threads.
    repeat_counts = num_repeats.tolist()

    def _sample_row(i):
        """Draw the candidates for row ``i`` as a list of class ids."""
        count = repeat_counts[i]
        if count == 0:
            return []
        positives = torch.flatten(true_classes[i, :])
        positives = positives[positives >= 0]
        # Work on a private copy so that masking this row's true classes
        # does not leak into the weights used for the other rows.
        weights = unigrams.clone().detach()
        weights[positives] = 0
        sampler = torch.utils.data.WeightedRandomSampler(
            weights, count, replacement=False)
        return list(sampler)

    # ThreadPool.map preserves input order, so the output stays row-ordered.
    with multiprocessing.pool.ThreadPool() as p:
        per_row = p.map(_sample_row, range(len(repeat_counts)))

    # Flatten in a single pass; repeated list concatenation is quadratic.
    flat = [cls for row in per_row for cls in row]

    # An explicit dtype keeps the result integral even when `flat` is empty
    # (torch.tensor([]) would otherwise default to a floating-point tensor).
    return torch.tensor(flat, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
def fixed_unigram_candidate_sampler(
|
|
|
@@ -61,6 +115,12 @@ def fixed_unigram_candidate_sampler( |
|
|
|
print('result:', result)
|
|
|
|
mask = (candidates == true_classes[indices[:, 1], :])
|
|
|
|
mask = mask.sum(1).to(torch.bool)
|
|
|
|
# append_true_classes = torch.full(( len(true_classes), ), -1)
|
|
|
|
# append_true_classes[~mask] = torch.flatten(candidates)[~mask]
|
|
|
|
# true_classes = torch.cat([
|
|
|
|
# append_true_classes.view(-1, 1),
|
|
|
|
# true_classes
|
|
|
|
# ], dim=1)
|
|
|
|
print('mask:', mask)
|
|
|
|
indices = indices[mask]
|
|
|
|
# result[indices] = 0
|
|
|
|