@@ -1,22 +1,30 @@ | |||
import torch | |||
import numpy as np | |||
def dfill(a):
    """Label every element of a 1D tensor with the start index of its run.

    A "run" is a maximal stretch of consecutive equal values.  For example,
    ``[1, 1, 2, 2, 2, 5]`` yields ``[0, 0, 2, 2, 2, 5]``.

    Args:
        a: 1D ``torch.Tensor``; typically already sorted so that equal values
            are adjacent.

    Returns:
        1D ``torch.int64`` tensor with the same number of elements as ``a``,
        allocated on ``a.device``.
    """
    n = torch.numel(a)
    device = a.device

    # Nothing to label; also avoids indexing into an empty boundary list.
    if n == 0:
        return torch.empty(0, dtype=torch.int64, device=device)

    # Run boundaries: 0, every position where the value changes, and n.
    # All pieces are created on the input's device so torch.cat() works for
    # non-CPU tensors as well.
    b = torch.cat([
        torch.zeros(1, dtype=torch.int64, device=device),
        torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1,
        torch.full((1,), n, dtype=torch.int64, device=device),
    ])

    # b[:-1] holds the run start indices, b[1:] - b[:-1] the run lengths.
    return torch.repeat_interleave(b[:-1], b[1:] - b[:-1])
def argunsort(s):
    """Invert a permutation produced by an argsort.

    Given ``s`` such that ``x[s]`` is sorted, returns ``u`` such that
    ``x[s][u]`` restores the original order of ``x``.

    Args:
        s: 1D integer index tensor holding a permutation of ``0..n-1``.

    Returns:
        1D ``torch.int64`` tensor of length ``n`` on ``s.device``.
    """
    n = torch.numel(s)
    # Allocate on the same device as the permutation so the scatter below
    # does not mix devices.
    u = torch.empty(n, dtype=torch.int64, device=s.device)
    u[s] = torch.arange(n, dtype=torch.int64, device=s.device)
    return u
def cumcount(a):
    """Number each element by how many equal values precede it.

    For example, ``[1, 1, 2, 1]`` yields ``[0, 1, 0, 2]``.

    Args:
        a: 1D ``torch.Tensor`` of sortable values.

    Returns:
        1D ``torch.int64`` tensor with one occurrence counter per element
        of ``a``.
    """
    n = torch.numel(a)
    if n == 0:
        return torch.empty(0, dtype=torch.int64, device=a.device)

    # The sort MUST be stable: equal values have to keep their original
    # relative order, otherwise the counters are handed out in the wrong
    # order.  numpy's default sort kind gives no such guarantee.
    s = np.argsort(a.detach().cpu().numpy(), kind='stable')
    s = torch.tensor(s, dtype=torch.int64, device=a.device)

    i = argunsort(s)
    b = a[s]

    # Position within the sorted array minus the start of the run it falls
    # into gives the per-value counter; [i] maps back to the input order.
    return (torch.arange(n, device=a.device) - dfill(b))[i]
@@ -21,6 +21,71 @@ from itertools import product, \ | |||
from functools import reduce | |||
def fixed_unigram_candidate_sampler_new(
        true_classes: torch.Tensor,
        num_repeats: torch.Tensor,
        unigrams: torch.Tensor,
        distortion: float = 1.) -> torch.Tensor:
    """Sample class indices per row while rejecting each row's known classes.

    For every row ``i`` of ``true_classes``, ``num_repeats[i]`` class indices
    are drawn with probability proportional to ``unigrams ** distortion``.
    A draw is rejected and redrawn in the next round when it equals an entry
    already stored in that row (the given true classes plus draws accepted in
    earlier rounds) or when it is flagged as a repeat within the current
    round.

    Args:
        true_classes: 2D integer tensor of shape (num_samples, num_true);
            unused slots are marked with -1.
        num_repeats: 1D integer tensor, number of samples to draw per row.
        unigrams: 1D tensor of non-negative sampling weights, one per class.
        distortion: exponent applied to ``unigrams`` before sampling.

    Returns:
        1D ``torch.long`` tensor of length ``num_repeats.sum()``; the samples
        for row 0 come first, then those for row 1, and so on.

    Raises:
        ValueError: if ``true_classes`` is not 2D, ``num_repeats`` is not 1D,
            or some row has fewer eligible classes than requested samples.
    """
    assert isinstance(true_classes, torch.Tensor)
    assert isinstance(num_repeats, torch.Tensor)
    assert isinstance(unigrams, torch.Tensor)

    distortion = float(distortion)

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    if torch.any((unigrams > 0).sum() - \
            (true_classes >= 0).sum(dim=1) < \
            num_repeats):
        raise ValueError('Not enough classes to choose from')

    # Apply the distortion exponent to the sampling weights; with the
    # default of 1.0 the weights are used unchanged.
    if distortion != 1.:
        unigrams = unigrams.to(torch.float64) ** distortion

    num_rows = true_classes.shape[0]

    # Number of occupied (non -1) slots per row; accepted draws are appended
    # right after them.
    true_class_count = true_classes.shape[1] - (true_classes == -1).sum(dim=1)

    # Reserve room for the maximum number of draws any row may accept.
    max_repeats = int(torch.max(num_repeats)) if torch.numel(num_repeats) > 0 else 0
    true_classes = torch.cat([
        true_classes,
        torch.full((num_rows, max_repeats), -1,
            dtype=true_classes.dtype)
    ], dim=1)

    # indices[:, 0] is the position in the flat result, indices[:, 1] the
    # row of true_classes that the position belongs to.
    indices = torch.repeat_interleave(torch.arange(num_rows), num_repeats)
    indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
        indices.view(-1, 1) ], dim=1)

    result = torch.zeros(len(indices), dtype=torch.long)

    # Each round redraws only the positions rejected in the previous round.
    while len(indices) > 0:
        candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
        candidates = torch.tensor(list(candidates)).view(-1, 1)

        # Order by candidate value first, then stably by row, so that within
        # each row the candidates appear in ascending order.
        inner_order = torch.argsort(candidates[:, 0])
        indices_np = indices[inner_order].detach().cpu().numpy()
        outer_order = np.argsort(indices_np[:, 1], kind='stable')
        outer_order = torch.tensor(outer_order, device=inner_order.device)

        candidates = candidates[inner_order][outer_order]
        indices = indices[inner_order][outer_order]

        # Reject candidates that already occur in their row of true_classes.
        mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool)

        # Also reject a draw when its value has been seen earlier in this
        # round AND it is not the first draw of its row.
        can_cum = cumcount(candidates[:, 0])
        ind_cum = cumcount(indices[:, 1])
        repeated = (can_cum > 0) & (ind_cum > 0)
        mask = mask | repeated

        # Append accepted candidates to their rows and advance the fill
        # counters accordingly.
        updated = indices[~mask]
        ofs = true_class_count[updated[:, 1]] + \
            cumcount(updated[:, 1])
        true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
        true_class_count[updated[:, 1]] = ofs + 1

        # Rejected positions are written too, but get overwritten in a later
        # round once they are redrawn.
        result[indices[:, 0]] = candidates.transpose(0, 1)
        indices = indices[mask]

    return result
def fixed_unigram_candidate_sampler_slow( | |||
true_classes: torch.Tensor, | |||
num_repeats: torch.Tensor, | |||
@@ -162,9 +227,9 @@ def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] | |||
# indices = indices.copy() | |||
# true_classes[indices[0], 0] = indices[1] | |||
t = time.time() | |||
cc = cumcount(indices[0].cpu().numpy()) | |||
cc = cumcount(indices[0]) | |||
print('cumcount() took:', time.time() - t) | |||
cc = torch.tensor(cc) | |||
# cc = torch.tensor(cc) | |||
t = time.time() | |||
true_classes[indices[0], cc] = indices[1] | |||
print('assignment took:', time.time() - t) | |||
@@ -1,26 +1,27 @@ | |||
from triacontagon.cumcount import dfill, \ | |||
argunsort, \ | |||
cumcount | |||
import torch | |||
import numpy as np | |||
def test_dfill_01():
    """dfill() labels each element with the start index of its run."""
    values = torch.tensor([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])
    expected = torch.tensor([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
    output = dfill(values)
    assert torch.all(output == expected)
def test_argunsort_01():
    """argunsort() inverts the permutation returned by a stable argsort."""
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
    # The expected result below assumes ties keep their original order, so
    # the sort has to be explicitly stable.
    order = np.argsort(values.numpy(), kind='stable')
    output = argunsort(torch.tensor(order))
    expected = torch.tensor([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
    assert torch.all(output == expected)
def test_cumcount_01():
    """cumcount() numbers each element by its count of earlier equal values."""
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
    expected = torch.tensor([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
    output = cumcount(values)
    assert torch.all(output == expected)
@@ -3,7 +3,8 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \ | |||
get_true_classes, \ | |||
negative_sample_adj_mat, \ | |||
negative_sample_data, \ | |||
get_edges_and_degrees | |||
get_edges_and_degrees, \ | |||
fixed_unigram_candidate_sampler_new | |||
from triacontagon.decode import dedicom_decoder | |||
import torch | |||
import time | |||
@@ -113,3 +114,12 @@ def test_negative_sample_data_01(): | |||
], dedicom_decoder) | |||
d_neg = negative_sample_data(d) | |||
def test_fixed_unigram_candidate_sampler_new_01():
    """Smoke test: the new sampler runs on a random sparse 10x10 matrix."""
    adj_mat = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse()

    true_classes, row_count = get_true_classes(adj_mat)
    _, degrees = get_edges_and_degrees(adj_mat)

    # Only checks that the call completes; the return value is not inspected.
    fixed_unigram_candidate_sampler_new(true_classes, row_count, degrees, 0.75)