IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Explorar el Código

New robust implementation of fixed_unigram_candidate_sampler() seems to be working.

master
Stanislaw Adaszewski hace 4 años
padre
commit
490b4f9281
Se han modificado 2 ficheros con 8 adiciones y 5 borrados
  1. +1
    -1
      src/triacontagon/cumcount.py
  2. +7
    -4
      src/triacontagon/sampling.py

+ 1
- 1
src/triacontagon/cumcount.py Ver fichero

@@ -9,7 +9,7 @@ def dfill(a):
torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1,
torch.tensor([n])
])
print('b:',b)
# print('b:',b)
res = torch.arange(n)[b[:-1]]
res = torch.repeat_interleave(res, b[1:] - b[:-1])
return res


+ 7
- 4
src/triacontagon/sampling.py Ver fichero

@@ -77,14 +77,17 @@ def fixed_unigram_candidate_sampler(
mask = mask | repeated
updated = indices[~mask]
ofs = true_class_count[updated[:, 1]] + \
cumcount(updated[:, 1])
true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
true_class_count[updated[:, 1]] = ofs + 1
if len(updated) > 0:
ofs = true_class_count[updated[:, 1]] + \
cumcount(updated[:, 1])
true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
true_class_count[updated[:, 1]] = ofs + 1
result[indices[:, 0]] = candidates.transpose(0, 1)
indices = indices[mask]
return result
def fixed_unigram_candidate_sampler_slow(
true_classes: torch.Tensor,


Cargando…
Cancelar
Guardar