IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

With this weird little trick the new fixed_unigram_candidate_sampler() finally seems to be fully robust and to actually work.

master
Stanislaw Adaszewski 3 years ago
parent
commit
356f3af3c3
1 changed files with 5 additions and 4 deletions
  1. +5
    -4
      src/triacontagon/sampling.py

+ 5
- 4
src/triacontagon/sampling.py View File

@@ -21,7 +21,7 @@ from itertools import product, \
from functools import reduce
def fixed_unigram_candidate_sampler_new(
def fixed_unigram_candidate_sampler(
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,
@@ -72,9 +72,10 @@ def fixed_unigram_candidate_sampler_new(
mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool)
can_cum = cumcount(candidates[:, 0])
# can_cum = cumcount(candidates[:, 0])
can_diff = torch.cat([ torch.tensor([1]), candidates[1:, 0] - candidates[:-1, 0] ])
ind_cum = cumcount(indices[:, 1])
repeated = (can_cum > 0) & (ind_cum > 0)
repeated = (can_diff == 0) & (ind_cum > 0)
# TODO: this is wrong, still requires work
mask = mask | repeated
@@ -141,7 +142,7 @@ def fixed_unigram_candidate_sampler_slow(
return torch.tensor(res)
def fixed_unigram_candidate_sampler(
def fixed_unigram_candidate_sampler_old(
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,


Loading…
Cancel
Save