IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Selaa lähdekoodia

Introduce fixed_unigram_candidate_sampler_new().

master
Stanislaw Adaszewski 4 vuotta sitten
vanhempi
commit
346f99fa83
4 muutettua tiedostoa jossa 107 lisäystä ja 23 poistoa
  1. +17
    -9
      src/triacontagon/cumcount.py
  2. +67
    -2
      src/triacontagon/sampling.py
  3. +12
    -11
      tests/triacontagon/test_cumcount.py
  4. +11
    -1
      tests/triacontagon/test_sampling.py

+ 17
- 9
src/triacontagon/cumcount.py Näytä tiedosto

@@ -1,22 +1,30 @@
import torch
import numpy as np
def dfill(a):
n = a.size
b = np.concatenate([[0], np.where(a[:-1] != a[1:])[0] + 1, [n]])
return np.arange(n)[b[:-1]].repeat(np.diff(b))
n = torch.numel(a)
b = torch.cat([
torch.tensor([0]),
torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1,
torch.tensor([n])
])
res = torch.arange(n)[b[:-1]]
res = torch.repeat_interleave(res, b[1:] - b[:-1])
return res
def argunsort(s):
n = s.size
u = np.empty(n, dtype=np.int64)
u[s] = np.arange(n)
n = torch.numel(s)
u = torch.empty(n, dtype=torch.int64)
u[s] = torch.arange(n)
return u
def cumcount(a):
n = a.size
s = a.argsort(kind='mergesort')
n = torch.numel(a)
s = np.argsort(a.detach().cpu().numpy())
s = torch.tensor(s, device=a.device)
i = argunsort(s)
b = a[s]
return (np.arange(n) - dfill(b))[i]
return (torch.arange(n) - dfill(b))[i]

+ 67
- 2
src/triacontagon/sampling.py Näytä tiedosto

@@ -21,6 +21,71 @@ from itertools import product, \
from functools import reduce
def fixed_unigram_candidate_sampler_new(
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,
distortion: float = 1.) -> torch.Tensor:
assert isinstance(true_classes, torch.Tensor)
assert isinstance(num_repeats, torch.Tensor)
assert isinstance(unigrams, torch.Tensor)
distortion = float(distortion)
if len(true_classes.shape) != 2:
raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
if len(num_repeats.shape) != 1:
raise ValueError('num_repeats must be 1D')
if torch.any((unigrams > 0).sum() - \
(true_classes >= 0).sum(dim=1) < \
num_repeats):
raise ValueError('Not enough classes to choose from')
true_class_count = true_classes.shape[1] - (true_classes == -1).sum(dim=1)
true_classes = torch.cat([
true_classes,
torch.full(( len(true_classes), torch.max(num_repeats) ), -1,
dtype=true_classes.dtype)
], dim=1)
indices = torch.repeat_interleave(torch.arange(len(unigrams)), num_repeats)
indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
indices.view(-1, 1) ], dim=1)
result = torch.zeros(len(indices), dtype=torch.long)
while len(indices) > 0:
candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
candidates = torch.tensor(list(candidates)).view(-1, 1)
inner_order = torch.argsort(candidates[:, 0])
indices_np = indices[inner_order].detach().cpu().numpy()
outer_order = np.argsort(indices_np[:, 1], kind='stable')
outer_order = torch.tensor(outer_order, device=inner_order.device)
candidates = candidates[inner_order][outer_order]
indices = indices[inner_order][outer_order]
mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool)
can_cum = cumcount(candidates[:, 0])
ind_cum = cumcount(indices[:, 1])
repeated = (can_cum > 0) & (ind_cum > 0)
mask = mask | repeated
updated = indices[~mask]
ofs = true_class_count[updated[:, 1]] + \
cumcount(updated[:, 1])
true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
true_class_count[updated[:, 1]] = ofs + 1
result[indices[:, 0]] = candidates.transpose(0, 1)
indices = indices[mask]
def fixed_unigram_candidate_sampler_slow(
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
@@ -162,9 +227,9 @@ def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
# indices = indices.copy()
# true_classes[indices[0], 0] = indices[1]
t = time.time()
cc = cumcount(indices[0].cpu().numpy())
cc = cumcount(indices[0])
print('cumcount() took:', time.time() - t)
cc = torch.tensor(cc)
# cc = torch.tensor(cc)
t = time.time()
true_classes[indices[0], cc] = indices[1]
print('assignment took:', time.time() - t)


+ 12
- 11
tests/triacontagon/test_cumcount.py Näytä tiedosto

@@ -1,26 +1,27 @@
from triacontagon.cumcount import dfill, \
argunsort, \
cumcount
import torch
import numpy as np
def test_dfill_01():
input = np.array([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])
input = torch.tensor([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])
output = dfill(input)
expected = np.array([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
assert np.all(output == expected)
expected = torch.tensor([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
assert torch.all(output == expected)
def test_argunsort_01():
input = np.array([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
output = np.argsort(input, kind='mergesort')
output = argunsort(output)
expected = np.array([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
assert np.all(output == expected)
input = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
output = np.argsort(input.numpy())
output = argunsort(torch.tensor(output))
expected = torch.tensor([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
assert torch.all(output == expected)
def test_cumcount_01():
input = np.array([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
input = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
output = cumcount(input)
expected = np.array([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
assert np.all(output == expected)
expected = torch.tensor([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
assert torch.all(output == expected)

+ 11
- 1
tests/triacontagon/test_sampling.py Näytä tiedosto

@@ -3,7 +3,8 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \
get_true_classes, \
negative_sample_adj_mat, \
negative_sample_data, \
get_edges_and_degrees
get_edges_and_degrees, \
fixed_unigram_candidate_sampler_new
from triacontagon.decode import dedicom_decoder
import torch
import time
@@ -113,3 +114,12 @@ def test_negative_sample_data_01():
], dedicom_decoder)
d_neg = negative_sample_data(d)
def test_fixed_unigram_candidate_sampler_new_01():
x = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse()
true_classes, row_count = get_true_classes(x)
edges, degrees = get_edges_and_degrees(x)
# import pdb
# pdb.set_trace()
_ = fixed_unigram_candidate_sampler_new(true_classes, row_count, degrees, 0.75)

Loading…
Peruuta
Tallenna