@@ -1,22 +1,30 @@ | |||||
import torch | |||||
import numpy as np | import numpy as np | ||||
def dfill(a):
    """Return, for every element of ``a``, the start index of its run.

    A "run" is a maximal stretch of consecutive equal values, so the input
    is expected to be a 1-D tensor in which equal values are adjacent
    (e.g. a sorted tensor).

    Args:
        a: 1-D tensor.

    Returns:
        An int64 tensor of the same length as ``a``, created on ``a``'s
        device; empty when ``a`` is empty.
    """
    n = torch.numel(a)
    if n == 0:
        # With no elements there is no run start to index; return early
        # instead of indexing into an empty arange.
        return torch.empty(0, dtype=torch.int64, device=a.device)

    # Run boundaries: 0, every position whose value differs from the
    # previous one, and n as the closing sentinel.
    b = torch.cat([
        torch.tensor([0], device=a.device),
        torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1,
        torch.tensor([n], device=a.device),
    ])
    res = torch.arange(n, device=a.device)[b[:-1]]
    # Repeat each run start by the length of its run.
    res = torch.repeat_interleave(res, b[1:] - b[:-1])
    return res


def argunsort(s):
    """Return the inverse of the permutation ``s``.

    Args:
        s: 1-D integer tensor holding a permutation of ``0..n-1``
            (typically the output of an argsort).

    Returns:
        An int64 tensor ``u`` on ``s``'s device such that ``u[s[k]] == k``.
    """
    n = torch.numel(s)
    u = torch.empty(n, dtype=torch.int64, device=s.device)
    u[s] = torch.arange(n, device=s.device)
    return u


def cumcount(a):
    """Return the running occurrence index of every element of ``a``.

    The k-th occurrence (counting from zero, in original order) of any
    value is mapped to ``k``.

    Args:
        a: 1-D tensor.

    Returns:
        An int64 tensor of the same length as ``a``, on ``a``'s device.
    """
    n = torch.numel(a)
    # The sort must be stable: equal values have to keep their original
    # relative order, otherwise occurrence numbers within a group would be
    # handed out in arbitrary order.
    s = np.argsort(a.detach().cpu().numpy(), kind='stable')
    s = torch.tensor(s, device=a.device)
    i = argunsort(s)
    b = a[s]
    # Position within the sorted run, mapped back to the original order.
    return (torch.arange(n, device=a.device) - dfill(b))[i]
@@ -21,6 +21,71 @@ from itertools import product, \ | |||||
from functools import reduce | from functools import reduce | ||||
def fixed_unigram_candidate_sampler_new(
    true_classes: torch.Tensor,
    num_repeats: torch.Tensor,
    unigrams: torch.Tensor,
    distortion: float = 1.) -> torch.Tensor:
    """Draw candidates weighted by ``unigrams``, rejecting known classes.

    For every row, ``num_repeats[row]`` output slots are filled. Candidates
    are drawn with ``unigrams`` as sampling weights; a draw is rejected and
    re-drawn when it already appears in that row of ``true_classes`` or is
    flagged as a repeat within the current round. Accepted draws are
    appended to a padded working copy of the row so later rounds reject
    them as well.

    Args:
        true_classes: 2-D tensor of shape (num_samples, num_true); ``-1``
            marks unused entries.
        num_repeats: 1-D tensor with the number of draws per row.
        unigrams: sampling weights, one per class.
        distortion: coerced to ``float``.
            NOTE(review): not applied to ``unigrams`` anywhere in this
            function - confirm whether that is intended.

    Returns:
        A 1-D int64 tensor with one sampled class per output slot, slots
        being laid out row after row.

    Raises:
        ValueError: if ``true_classes`` is not 2-D, ``num_repeats`` is not
            1-D, or some row has fewer eligible classes than requested
            draws.
    """
    assert isinstance(true_classes, torch.Tensor)
    assert isinstance(num_repeats, torch.Tensor)
    assert isinstance(unigrams, torch.Tensor)

    distortion = float(distortion)

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    # Classes with non-zero weight, minus those already taken in each row,
    # must cover the number of draws requested for that row.
    if torch.any((unigrams > 0).sum() -
                 (true_classes >= 0).sum(dim=1) <
                 num_repeats):
        raise ValueError('Not enough classes to choose from')

    # Number of used (non -1) entries per row; doubles as the write offset
    # for newly accepted candidates.
    true_class_count = true_classes.shape[1] - (true_classes == -1).sum(dim=1)

    # Work on a copy padded with -1 so accepted candidates can be appended.
    # The padding width is converted to a plain int for torch.full.
    true_classes = torch.cat([
        true_classes,
        torch.full((len(true_classes), int(torch.max(num_repeats))), -1,
                   dtype=true_classes.dtype)
    ], dim=1)

    # Column 0: output slot, column 1: row the slot belongs to.
    # NOTE(review): rows are enumerated with len(unigrams), yet column 1 is
    # used to index true_classes - presumably both lengths agree; confirm.
    indices = torch.repeat_interleave(torch.arange(len(unigrams)), num_repeats)
    indices = torch.cat([torch.arange(len(indices)).view(-1, 1),
                         indices.view(-1, 1)], dim=1)

    result = torch.zeros(len(indices), dtype=torch.long)

    while len(indices) > 0:
        candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
        candidates = torch.tensor(list(candidates)).view(-1, 1)

        # Order by candidate, then stably by row, so each row's candidates
        # end up contiguous and sorted.
        inner_order = torch.argsort(candidates[:, 0])
        indices_np = indices[inner_order].detach().cpu().numpy()
        outer_order = np.argsort(indices_np[:, 1], kind='stable')
        outer_order = torch.tensor(outer_order, device=inner_order.device)

        candidates = candidates[inner_order][outer_order]
        indices = indices[inner_order][outer_order]

        # Reject draws that already occur in the row's (padded) classes.
        mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool)

        # Also reject draws that are neither the first occurrence of their
        # candidate value nor the first slot of their row in this round.
        can_cum = cumcount(candidates[:, 0])
        ind_cum = cumcount(indices[:, 1])
        repeated = (can_cum > 0) & (ind_cum > 0)
        mask = mask | repeated

        # Append the accepted draws to their rows and advance the offsets.
        updated = indices[~mask]
        ofs = true_class_count[updated[:, 1]] + \
            cumcount(updated[:, 1])
        true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
        true_class_count[updated[:, 1]] = ofs + 1

        # Rejected slots are written too; a later round overwrites them.
        result[indices[:, 0]] = candidates.transpose(0, 1)
        indices = indices[mask]

    return result
def fixed_unigram_candidate_sampler_slow( | def fixed_unigram_candidate_sampler_slow( | ||||
true_classes: torch.Tensor, | true_classes: torch.Tensor, | ||||
num_repeats: torch.Tensor, | num_repeats: torch.Tensor, | ||||
@@ -162,9 +227,9 @@ def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] | |||||
# indices = indices.copy() | # indices = indices.copy() | ||||
# true_classes[indices[0], 0] = indices[1] | # true_classes[indices[0], 0] = indices[1] | ||||
t = time.time() | t = time.time() | ||||
cc = cumcount(indices[0].cpu().numpy()) | |||||
cc = cumcount(indices[0]) | |||||
print('cumcount() took:', time.time() - t) | print('cumcount() took:', time.time() - t) | ||||
cc = torch.tensor(cc) | |||||
# cc = torch.tensor(cc) | |||||
t = time.time() | t = time.time() | ||||
true_classes[indices[0], cc] = indices[1] | true_classes[indices[0], cc] = indices[1] | ||||
print('assignment took:', time.time() - t) | print('assignment took:', time.time() - t) | ||||
@@ -1,26 +1,27 @@ | |||||
from triacontagon.cumcount import dfill, \ | from triacontagon.cumcount import dfill, \ | ||||
argunsort, \ | argunsort, \ | ||||
cumcount | cumcount | ||||
import torch | |||||
import numpy as np | import numpy as np | ||||
def test_dfill_01():
    """dfill maps every element of a run-sorted tensor to its run's start."""
    # Named `values` rather than `input` to avoid shadowing the builtin.
    values = torch.tensor([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])
    output = dfill(values)
    expected = torch.tensor([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
    assert torch.all(output == expected)
def test_argunsort_01():
    """argunsort inverts the stable argsort permutation of the input."""
    # Named `values` rather than `input` to avoid shadowing the builtin.
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
    # The expected permutation assumes ties keep their original order, so
    # the sort has to be explicitly stable.
    order = np.argsort(values.numpy(), kind='stable')
    output = argunsort(torch.tensor(order))
    expected = torch.tensor([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
    assert torch.all(output == expected)
def test_cumcount_01():
    """cumcount numbers the occurrences of each value in original order."""
    # Named `values` rather than `input` to avoid shadowing the builtin.
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
    output = cumcount(values)
    expected = torch.tensor([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
    assert torch.all(output == expected)
@@ -3,7 +3,8 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \ | |||||
get_true_classes, \ | get_true_classes, \ | ||||
negative_sample_adj_mat, \ | negative_sample_adj_mat, \ | ||||
negative_sample_data, \ | negative_sample_data, \ | ||||
get_edges_and_degrees | |||||
get_edges_and_degrees, \ | |||||
fixed_unigram_candidate_sampler_new | |||||
from triacontagon.decode import dedicom_decoder | from triacontagon.decode import dedicom_decoder | ||||
import torch | import torch | ||||
import time | import time | ||||
@@ -113,3 +114,12 @@ def test_negative_sample_data_01(): | |||||
], dedicom_decoder) | ], dedicom_decoder) | ||||
d_neg = negative_sample_data(d) | d_neg = negative_sample_data(d) | ||||
def test_fixed_unigram_candidate_sampler_new_01():
    """Smoke test: the new sampler runs on a random sparse 10x10 matrix."""
    adj_mat = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse()

    true_classes, row_count = get_true_classes(adj_mat)
    _edges, degrees = get_edges_and_degrees(adj_mat)

    # Only checks that the call completes; the returned value is discarded.
    fixed_unigram_candidate_sampler_new(true_classes, row_count, degrees, 0.75)