diff --git a/src/triacontagon/cumcount.py b/src/triacontagon/cumcount.py
index ba1f23f..faee2be 100644
--- a/src/triacontagon/cumcount.py
+++ b/src/triacontagon/cumcount.py
@@ -1,22 +1,39 @@
+import torch
 import numpy as np
 
 
 def dfill(a):
-    n = a.size
-    b = np.concatenate([[0], np.where(a[:-1] != a[1:])[0] + 1, [n]])
-    return np.arange(n)[b[:-1]].repeat(np.diff(b))
+    """Map each element of the 1-D tensor `a` to the start index of its run
+    of consecutive equal values."""
+    n = torch.numel(a)
+    if n == 0:  # no runs; indexing arange(0) with [0] below would raise
+        return torch.arange(0, device=a.device)
+    b = torch.cat([
+        torch.tensor([0], device=a.device),
+        torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1,
+        torch.tensor([n], device=a.device)
+    ])
+    res = torch.arange(n, device=a.device)[b[:-1]]
+    res = torch.repeat_interleave(res, b[1:] - b[:-1])
+    return res
 
 
 def argunsort(s):
-    n = s.size
-    u = np.empty(n, dtype=np.int64)
-    u[s] = np.arange(n)
+    """Return the inverse of the permutation `s`, i.e. `u[s[k]] == k`."""
+    n = torch.numel(s)
+    u = torch.empty(n, dtype=torch.int64, device=s.device)
+    u[s] = torch.arange(n, device=s.device)
     return u
 
 
 def cumcount(a):
-    n = a.size
-    s = a.argsort(kind='mergesort')
+    """Return, for each element of the 1-D tensor `a`, the number of equal
+    elements that precede it."""
+    n = torch.numel(a)
+    # The sort must be stable so equal values keep their input order; the
+    # NumPy code being replaced asked for kind='mergesort' for that reason.
+    s = np.argsort(a.detach().cpu().numpy(), kind='stable')
+    s = torch.tensor(s, device=a.device)
     i = argunsort(s)
     b = a[s]
-    return (np.arange(n) - dfill(b))[i]
+    return (torch.arange(n, device=a.device) - dfill(b))[i]
diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py
index ab033dd..76e42dc 100644
--- a/src/triacontagon/sampling.py
+++ b/src/triacontagon/sampling.py
@@ -21,6 +21,84 @@ from itertools import product, \
 from functools import reduce
 
 
+def fixed_unigram_candidate_sampler_new(
+    true_classes: torch.Tensor,
+    num_repeats: torch.Tensor,
+    unigrams: torch.Tensor,
+    distortion: float = 1.) -> torch.Tensor:
+    """Draw candidate classes by vectorized rejection sampling.
+
+    For row i of `true_classes`, `num_repeats[i]` candidates are drawn with
+    weights `unigrams`. A draw is retried while it equals an entry of that
+    row (-1 entries are padding) or is flagged as a repeat. `distortion` is
+    accepted but not applied in this function.
+
+    Returns a 1-D long tensor holding `num_repeats.sum()` sampled classes.
+    """
+
+    assert isinstance(true_classes, torch.Tensor)
+    assert isinstance(num_repeats, torch.Tensor)
+    assert isinstance(unigrams, torch.Tensor)
+    distortion = float(distortion)
+
+    if len(true_classes.shape) != 2:
+        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
+
+    if len(num_repeats.shape) != 1:
+        raise ValueError('num_repeats must be 1D')
+
+    if torch.any((unigrams > 0).sum() - \
+        (true_classes >= 0).sum(dim=1) < \
+        num_repeats):
+        raise ValueError('Not enough classes to choose from')
+
+    true_class_count = true_classes.shape[1] - (true_classes == -1).sum(dim=1)
+    true_classes = torch.cat([
+        true_classes,
+        torch.full(( len(true_classes), torch.max(num_repeats) ), -1,
+            dtype=true_classes.dtype)
+    ], dim=1)
+
+    # One row index per requested draw; num_repeats has one entry per row of
+    # true_classes, so the index range must be the number of rows.
+    indices = torch.repeat_interleave(torch.arange(len(true_classes)), num_repeats)
+    indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
+        indices.view(-1, 1) ], dim=1)
+
+    result = torch.zeros(len(indices), dtype=torch.long)
+
+    while len(indices) > 0:
+        candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
+        candidates = torch.tensor(list(candidates)).view(-1, 1)
+
+        inner_order = torch.argsort(candidates[:, 0])
+        indices_np = indices[inner_order].detach().cpu().numpy()
+        outer_order = np.argsort(indices_np[:, 1], kind='stable')
+        outer_order = torch.tensor(outer_order, device=inner_order.device)
+
+        candidates = candidates[inner_order][outer_order]
+        indices = indices[inner_order][outer_order]
+
+        mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool)
+
+        can_cum = cumcount(candidates[:, 0])
+        ind_cum = cumcount(indices[:, 1])
+        repeated = (can_cum > 0) & (ind_cum > 0)
+
+        mask = mask | repeated
+
+        updated = indices[~mask]
+        ofs = true_class_count[updated[:, 1]] + \
+            cumcount(updated[:, 1])
+        true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
+        true_class_count[updated[:, 1]] = ofs + 1
+
+        result[indices[:, 0]] = candidates.transpose(0, 1)
+        indices = indices[mask]
+
+    return result
+
+
 def fixed_unigram_candidate_sampler_slow(
     true_classes: torch.Tensor,
     num_repeats: torch.Tensor,
@@ -162,9 +240,8 @@ def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
     # indices = indices.copy()
     # true_classes[indices[0], 0] = indices[1]
     t = time.time()
-    cc = cumcount(indices[0].cpu().numpy())
+    cc = cumcount(indices[0])
     print('cumcount() took:', time.time() - t)
-    cc = torch.tensor(cc)
     t = time.time()
     true_classes[indices[0], cc] = indices[1]
     print('assignment took:', time.time() - t)
diff --git a/tests/triacontagon/test_cumcount.py b/tests/triacontagon/test_cumcount.py
index b9a8780..694b46c 100644
--- a/tests/triacontagon/test_cumcount.py
+++ b/tests/triacontagon/test_cumcount.py
@@ -1,26 +1,28 @@
 from triacontagon.cumcount import dfill, \
     argunsort, \
     cumcount
+import torch
 import numpy as np
 
 
 def test_dfill_01():
-    input = np.array([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])
+    input = torch.tensor([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])
     output = dfill(input)
-    expected = np.array([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
-    assert np.all(output == expected)
+    expected = torch.tensor([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
+    assert torch.all(output == expected)
 
 
 def test_argunsort_01():
-    input = np.array([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
-    output = np.argsort(input, kind='mergesort')
-    output = argunsort(output)
-    expected = np.array([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
-    assert np.all(output == expected)
+    input = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
+    # Stable sort: the expected permutation depends on the order of ties.
+    output = np.argsort(input.numpy(), kind='stable')
+    output = argunsort(torch.tensor(output))
+    expected = torch.tensor([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
+    assert torch.all(output == expected)
 
 
 def test_cumcount_01():
-    input = np.array([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
+    input = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
     output = cumcount(input)
-    expected = np.array([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
-    assert np.all(output == expected)
+    expected = torch.tensor([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
+    assert torch.all(output == expected)
diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py
index 088cee3..618f172 100644
--- a/tests/triacontagon/test_sampling.py
+++ b/tests/triacontagon/test_sampling.py
@@ -3,7 +3,8 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \
     get_true_classes, \
     negative_sample_adj_mat, \
     negative_sample_data, \
-    get_edges_and_degrees
+    get_edges_and_degrees, \
+    fixed_unigram_candidate_sampler_new
 from triacontagon.decode import dedicom_decoder
 import torch
 import time
@@ -113,3 +114,11 @@ def test_negative_sample_data_01():
     ], dedicom_decoder)
 
     d_neg = negative_sample_data(d)
+
+
+def test_fixed_unigram_candidate_sampler_new_01():
+    x = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse()
+    true_classes, row_count = get_true_classes(x)
+    edges, degrees = get_edges_and_degrees(x)
+    samples = fixed_unigram_candidate_sampler_new(true_classes, row_count, degrees, 0.75)
+    assert isinstance(samples, torch.Tensor)