IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Quellcode durchsuchen

fixed_unigram_candidate_sampler() still requires some work to be perfect.

master
Stanislaw Adaszewski vor 4 Jahren
Ursprung
Commit
6cf8251539
2 geänderte Dateien mit 61 neuen und 1 gelöschten Zeilen
  1. +1
    -1
      src/triacontagon/loop.py
  2. +60
    -0
      src/triacontagon/sampling.py

+ 1
- 1
src/triacontagon/loop.py Datei anzeigen

@@ -1,6 +1,6 @@
from .model import Model, \
TrainingBatch
from .batch import Batcher
from .batch import DualBatcher
from .sampling import negative_sample_data
from .data import Data
import torch


+ 60
- 0
src/triacontagon/sampling.py Datei anzeigen

@@ -14,6 +14,60 @@ from .data import Data, \
EdgeType
from .cumcount import cumcount
import time
import multiprocessing
import multiprocessing.pool
from itertools import product, \
repeat
from functools import reduce
def fixed_unigram_candidate_sampler_slow(
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,
distortion: float = 1.) -> torch.Tensor:
assert isinstance(true_classes, torch.Tensor)
assert isinstance(num_repeats, torch.Tensor)
assert isinstance(unigrams, torch.Tensor)
distortion = float(distortion)
if len(true_classes.shape) != 2:
raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
if len(num_repeats.shape) != 1:
raise ValueError('num_repeats must be 1D')
if torch.any((unigrams > 0).sum() - \
(true_classes >= 0).sum(dim=1) < \
num_repeats):
raise ValueError('Not enough classes to choose from')
res = []
if distortion != 1.:
unigrams = unigrams.to(torch.float64)
unigrams = unigrams ** distortion
def fun(i):
if i and i % 100 == 0:
print(i)
if num_repeats[i] == 0:
return []
pos = torch.flatten(true_classes[i, :])
pos = pos[pos >= 0]
w = unigrams.clone().detach()
w[pos] = 0
sampler = torch.utils.data.WeightedRandomSampler(w,
num_repeats[i].item(), replacement=False)
res = list(sampler)
return res
with multiprocessing.pool.ThreadPool() as p:
res = p.map(fun, range(len(num_repeats)))
res = reduce(list.__add__, res, [])
return torch.tensor(res)
def fixed_unigram_candidate_sampler(
@@ -61,6 +115,12 @@ def fixed_unigram_candidate_sampler(
print('result:', result)
mask = (candidates == true_classes[indices[:, 1], :])
mask = mask.sum(1).to(torch.bool)
# append_true_classes = torch.full(( len(true_classes), ), -1)
# append_true_classes[~mask] = torch.flatten(candidates)[~mask]
# true_classes = torch.cat([
# append_true_classes.view(-1, 1),
# true_classes
# ], dim=1)
print('mask:', mask)
indices = indices[mask]
# result[indices] = 0


Laden…
Abbrechen
Speichern