IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

New fixed_unigram_candidate_sampler() still requires work.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
d7d442c5e3
共有 1 個文件被更改,包括 5 次插入2 次删除
  1. +5
    -2
      src/triacontagon/sampling.py

+ 5
- 2
src/triacontagon/sampling.py 查看文件

@@ -21,7 +21,7 @@ from itertools import product, \
from functools import reduce
def fixed_unigram_candidate_sampler(
def fixed_unigram_candidate_sampler_new(
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,
@@ -57,6 +57,8 @@ def fixed_unigram_candidate_sampler(
result = torch.zeros(len(indices), dtype=torch.long)
while len(indices) > 0:
print(len(indices))
candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
candidates = torch.tensor(list(candidates)).view(-1, 1)
@@ -73,6 +75,7 @@ def fixed_unigram_candidate_sampler(
can_cum = cumcount(candidates[:, 0])
ind_cum = cumcount(indices[:, 1])
repeated = (can_cum > 0) & (ind_cum > 0)
# TODO: this is wrong, still requires work
mask = mask | repeated
@@ -138,7 +141,7 @@ def fixed_unigram_candidate_sampler_slow(
return torch.tensor(res)
def fixed_unigram_candidate_sampler_old(
def fixed_unigram_candidate_sampler(
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,


Loading…
取消
儲存