diff --git a/src/triacontagon/loop.py b/src/triacontagon/loop.py index e5f96bf..870c07e 100644 --- a/src/triacontagon/loop.py +++ b/src/triacontagon/loop.py @@ -1,6 +1,6 @@ from .model import Model, \ TrainingBatch -from .batch import Batcher +from .batch import DualBatcher from .sampling import negative_sample_data from .data import Data import torch diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index bcec040..ab033dd 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -14,6 +14,60 @@ from .data import Data, \ EdgeType from .cumcount import cumcount import time +import multiprocessing +import multiprocessing.pool +from itertools import product, \ + repeat +from functools import reduce + + +def fixed_unigram_candidate_sampler_slow( + true_classes: torch.Tensor, + num_repeats: torch.Tensor, + unigrams: torch.Tensor, + distortion: float = 1.) -> torch.Tensor: + + assert isinstance(true_classes, torch.Tensor) + assert isinstance(num_repeats, torch.Tensor) + assert isinstance(unigrams, torch.Tensor) + distortion = float(distortion) + + if len(true_classes.shape) != 2: + raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)') + + if len(num_repeats.shape) != 1: + raise ValueError('num_repeats must be 1D') + + if torch.any((unigrams > 0).sum() - \ + (true_classes >= 0).sum(dim=1) < \ + num_repeats): + raise ValueError('Not enough classes to choose from') + + res = [] + + if distortion != 1.: + unigrams = unigrams.to(torch.float64) + unigrams = unigrams ** distortion + + def fun(i): + if i and i % 100 == 0: + print(i) + if num_repeats[i] == 0: + return [] + pos = torch.flatten(true_classes[i, :]) + pos = pos[pos >= 0] + w = unigrams.clone().detach() + w[pos] = 0 + sampler = torch.utils.data.WeightedRandomSampler(w, + num_repeats[i].item(), replacement=False) + res = list(sampler) + return res + + with multiprocessing.pool.ThreadPool() as p: + res = p.map(fun, range(len(num_repeats))) + res = reduce(list.__add__, res, []) + + return torch.tensor(res) def fixed_unigram_candidate_sampler( @@ -61,6 +115,12 @@ def fixed_unigram_candidate_sampler( print('result:', result) mask = (candidates == true_classes[indices[:, 1], :]) mask = mask.sum(1).to(torch.bool) + # append_true_classes = torch.full(( len(true_classes), ), -1) + # append_true_classes[~mask] = torch.flatten(candidates)[~mask] + # true_classes = torch.cat([ + # append_true_classes.view(-1, 1), + # true_classes + # ], dim=1) print('mask:', mask) indices = indices[mask] # result[indices] = 0