IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

New robust implementation of fixed_unigram_candidate_sampler() seems to be working.

master
Stanislaw Adaszewski 4 years ago
parent
commit
490b4f9281
2 changed files with 8 additions and 5 deletions
  1. +1
    -1
      src/triacontagon/cumcount.py
  2. +7
    -4
      src/triacontagon/sampling.py

+ 1
- 1
src/triacontagon/cumcount.py View File

@@ -9,7 +9,7 @@ def dfill(a):
torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1, torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1,
torch.tensor([n]) torch.tensor([n])
]) ])
print('b:',b)
# print('b:',b)
res = torch.arange(n)[b[:-1]] res = torch.arange(n)[b[:-1]]
res = torch.repeat_interleave(res, b[1:] - b[:-1]) res = torch.repeat_interleave(res, b[1:] - b[:-1])
return res return res


+ 7
- 4
src/triacontagon/sampling.py View File

@@ -77,14 +77,17 @@ def fixed_unigram_candidate_sampler(
mask = mask | repeated mask = mask | repeated
updated = indices[~mask] updated = indices[~mask]
ofs = true_class_count[updated[:, 1]] + \
cumcount(updated[:, 1])
true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
true_class_count[updated[:, 1]] = ofs + 1
if len(updated) > 0:
ofs = true_class_count[updated[:, 1]] + \
cumcount(updated[:, 1])
true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
true_class_count[updated[:, 1]] = ofs + 1
result[indices[:, 0]] = candidates.transpose(0, 1) result[indices[:, 0]] = candidates.transpose(0, 1)
indices = indices[mask] indices = indices[mask]
return result
def fixed_unigram_candidate_sampler_slow( def fixed_unigram_candidate_sampler_slow(
true_classes: torch.Tensor, true_classes: torch.Tensor,


Loading…
Cancel
Save