| @@ -1,22 +1,30 @@ | |||||
| import torch | |||||
| import numpy as np | import numpy as np | ||||
def dfill(a):
    """Map every element of a 1-D tensor to the start index of its run.

    A "run" is a maximal stretch of equal consecutive values, so for
    ``[1, 1, 2, 2, 2, 3]`` the result is ``[0, 0, 2, 2, 2, 5]``.

    Args:
        a: 1-D tensor; equal values are expected to be adjacent (e.g. the
            tensor is sorted) for the result to identify groups.

    Returns:
        An int64 tensor of the same length as ``a``, on ``a``'s device.
    """
    n = torch.numel(a)
    device = a.device

    # Without this guard the boundary list would be [0, 0] and indexing an
    # empty range with it raises IndexError.
    if n == 0:
        return torch.empty(0, dtype=torch.int64, device=device)

    # Positions where a new run begins (excluding position 0).
    changes = torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1

    # Run boundaries: [0, change_1, ..., change_k, n].
    bounds = torch.cat([
        torch.zeros(1, dtype=torch.int64, device=device),
        changes,
        torch.full((1,), n, dtype=torch.int64, device=device),
    ])

    # Repeat every run start by the length of its run.
    return torch.repeat_interleave(bounds[:-1], bounds[1:] - bounds[:-1])
def argunsort(s):
    """Invert a sorting permutation.

    Given ``s`` such that ``x[s]`` is sorted, the returned tensor ``u``
    satisfies ``x[s][u] == x``, i.e. ``u[s[k]] == k`` for every ``k``.

    Args:
        s: 1-D integer (int64) tensor holding a permutation of
            ``0 .. len(s) - 1``.

    Returns:
        An int64 tensor with the inverse permutation, on ``s``'s device.
    """
    n = torch.numel(s)
    # Allocate on the same device as the index tensor so that the scatter
    # below does not mix devices.
    u = torch.empty(n, dtype=torch.int64, device=s.device)
    u[s] = torch.arange(n, dtype=torch.int64, device=s.device)
    return u
def cumcount(a):
    """Number each element by how many equal elements precede it.

    For ``[1, 1, 2, 1]`` the result is ``[0, 1, 0, 2]``.

    Args:
        a: 1-D tensor of sortable values.

    Returns:
        An int64 tensor of the same length as ``a``, on ``a``'s device.
    """
    n = torch.numel(a)
    device = a.device

    if n == 0:
        return torch.empty(0, dtype=torch.int64, device=device)

    # The sort MUST be stable: ties have to keep their original relative
    # order, otherwise the per-group counters are handed out in the wrong
    # order. numpy's default sort kind gives no such guarantee.
    order = np.argsort(a.detach().cpu().numpy(), kind='stable')
    s_cpu = torch.as_tensor(order, dtype=torch.int64)

    # Invert the permutation on the CPU, then move both index tensors to the
    # input's device.
    i = argunsort(s_cpu).to(device)
    s = s_cpu.to(device)

    b = a[s]
    # Position inside the sorted array minus the start of the element's
    # group is the element's rank within its group; [i] restores the
    # original order.
    starts = dfill(b).to(device)
    return (torch.arange(n, device=device) - starts)[i]
| @@ -21,6 +21,71 @@ from itertools import product, \ | |||||
| from functools import reduce | from functools import reduce | ||||
def fixed_unigram_candidate_sampler_new(
    true_classes: torch.Tensor,
    num_repeats: torch.Tensor,
    unigrams: torch.Tensor,
    distortion: float = 1.) -> torch.Tensor:
    """Draw negative classes per row by vectorised rejection sampling.

    For every row ``i`` of ``true_classes``, ``num_repeats[i]`` classes are
    drawn with probability proportional to ``unigrams ** distortion``. A
    draw is rejected (and retried in the next round) when it is one of the
    row's true classes, was already accepted for that row, or is flagged as
    a repeat within the current round.

    Args:
        true_classes: 2-D integer tensor ``(num_samples, num_true)``; unused
            slots hold ``-1`` (assumed to be packed at the end of each row,
            since accepted draws are appended after the last non ``-1``
            entry -- NOTE(review): confirm against get_true_classes()).
        num_repeats: 1-D integer tensor, number of draws per row.
        unigrams: 1-D tensor of non-negative class weights.
        distortion: exponent applied to ``unigrams`` before sampling.

    Returns:
        A 1-D long tensor of length ``num_repeats.sum()``; the draws for row
        0 come first, then those for row 1, and so on.

    Raises:
        ValueError: if the inputs have the wrong rank or a row does not have
            enough candidate classes left to draw from.
    """
    assert isinstance(true_classes, torch.Tensor)
    assert isinstance(num_repeats, torch.Tensor)
    assert isinstance(unigrams, torch.Tensor)
    distortion = float(distortion)

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    if torch.any((unigrams > 0).sum() -
                 (true_classes >= 0).sum(dim=1) <
                 num_repeats):
        raise ValueError('Not enough classes to choose from')

    # Nothing to draw (this also covers an empty num_repeats, for which
    # torch.max() below would raise).
    if torch.numel(num_repeats) == 0 or int(num_repeats.sum()) == 0:
        return torch.zeros(0, dtype=torch.long)

    # Sampling weights: the unigram weights raised to the distortion power.
    weights = unigrams.to(torch.float64)
    if distortion != 1.:
        weights = weights ** distortion

    # Number of occupied slots per row; accepted draws are appended there so
    # that later rounds see them as "true" classes and reject duplicates.
    true_class_count = true_classes.shape[1] - (true_classes == -1).sum(dim=1)
    padding = torch.full(
        (len(true_classes), int(torch.max(num_repeats))), -1,
        dtype=true_classes.dtype)
    true_classes = torch.cat([true_classes, padding], dim=1)

    # One pending entry per requested draw: column 0 is the position in the
    # output, column 1 the row of true_classes the draw belongs to.
    rows = torch.repeat_interleave(torch.arange(len(num_repeats)), num_repeats)
    indices = torch.cat([torch.arange(len(rows)).view(-1, 1),
                         rows.view(-1, 1)], dim=1)

    result = torch.zeros(len(indices), dtype=torch.long)

    while len(indices) > 0:
        candidates = torch.multinomial(
            weights, len(indices), replacement=True).view(-1, 1)

        # Order the round by (row, candidate): sort by candidate first, then
        # stable-sort by row.
        inner_order = torch.argsort(candidates[:, 0])
        indices_np = indices[inner_order].detach().cpu().numpy()
        outer_order = np.argsort(indices_np[:, 1], kind='stable')
        outer_order = torch.tensor(outer_order, device=inner_order.device)

        candidates = candidates[inner_order][outer_order]
        indices = indices[inner_order][outer_order]

        # Reject draws that hit a true (or previously accepted) class.
        mask = (true_classes[indices[:, 1]] == candidates).any(dim=1)

        # Reject repeats inside this round.
        # NOTE(review): this test is conservative -- it also rejects a value
        # that merely appeared earlier for a *different* row, as long as the
        # draw is not the first one of its row. Such draws are simply
        # retried in the next round.
        can_cum = cumcount(candidates[:, 0])
        ind_cum = cumcount(indices[:, 1])
        repeated = (can_cum > 0) & (ind_cum > 0)
        mask = mask | repeated

        accepted_rows = indices[~mask][:, 1]
        if len(accepted_rows) > 0:
            # Append every accepted draw to its row of true_classes.
            ofs = true_class_count[accepted_rows] + cumcount(accepted_rows)
            true_classes[accepted_rows, ofs] = \
                candidates[~mask][:, 0].to(true_classes.dtype)
            # index_add_ accumulates over duplicate row ids, which a plain
            # indexed assignment does not guarantee.
            true_class_count.index_add_(
                0, accepted_rows, torch.ones_like(accepted_rows))

        # Rejected positions are written too; they get overwritten once a
        # later round accepts a draw for them.
        result[indices[:, 0]] = candidates[:, 0]
        indices = indices[mask]

    return result
| def fixed_unigram_candidate_sampler_slow( | def fixed_unigram_candidate_sampler_slow( | ||||
| true_classes: torch.Tensor, | true_classes: torch.Tensor, | ||||
| num_repeats: torch.Tensor, | num_repeats: torch.Tensor, | ||||
| @@ -162,9 +227,9 @@ def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] | |||||
| # indices = indices.copy() | # indices = indices.copy() | ||||
| # true_classes[indices[0], 0] = indices[1] | # true_classes[indices[0], 0] = indices[1] | ||||
| t = time.time() | t = time.time() | ||||
| cc = cumcount(indices[0].cpu().numpy()) | |||||
| cc = cumcount(indices[0]) | |||||
| print('cumcount() took:', time.time() - t) | print('cumcount() took:', time.time() - t) | ||||
| cc = torch.tensor(cc) | |||||
| # cc = torch.tensor(cc) | |||||
| t = time.time() | t = time.time() | ||||
| true_classes[indices[0], cc] = indices[1] | true_classes[indices[0], cc] = indices[1] | ||||
| print('assignment took:', time.time() - t) | print('assignment took:', time.time() - t) | ||||
| @@ -1,26 +1,27 @@ | |||||
| from triacontagon.cumcount import dfill, \ | from triacontagon.cumcount import dfill, \ | ||||
| argunsort, \ | argunsort, \ | ||||
| cumcount | cumcount | ||||
| import torch | |||||
| import numpy as np | import numpy as np | ||||
def test_dfill_01():
    """dfill() replaces every element by the first index of its run."""
    values = torch.tensor([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])
    expected = torch.tensor([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
    assert torch.all(dfill(values) == expected)
def test_argunsort_01():
    """argunsort() inverts the permutation produced by a stable argsort."""
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
    # The expected permutation below assumes ties keep their original order,
    # so the sort has to be requested as stable explicitly.
    order = np.argsort(values.numpy(), kind='stable')
    output = argunsort(torch.tensor(order))
    expected = torch.tensor([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
    assert torch.all(output == expected)
def test_cumcount_01():
    """cumcount() numbers the occurrences of every value in input order."""
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
    expected = torch.tensor([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
    assert torch.all(cumcount(values) == expected)
| @@ -3,7 +3,8 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \ | |||||
| get_true_classes, \ | get_true_classes, \ | ||||
| negative_sample_adj_mat, \ | negative_sample_adj_mat, \ | ||||
| negative_sample_data, \ | negative_sample_data, \ | ||||
| get_edges_and_degrees | |||||
| get_edges_and_degrees, \ | |||||
| fixed_unigram_candidate_sampler_new | |||||
| from triacontagon.decode import dedicom_decoder | from triacontagon.decode import dedicom_decoder | ||||
| import torch | import torch | ||||
| import time | import time | ||||
| @@ -113,3 +114,12 @@ def test_negative_sample_data_01(): | |||||
| ], dedicom_decoder) | ], dedicom_decoder) | ||||
| d_neg = negative_sample_data(d) | d_neg = negative_sample_data(d) | ||||
def test_fixed_unigram_candidate_sampler_new_01():
    """Smoke test: the vectorised sampler runs on a random sparse matrix."""
    # Random 10x10 sparse 0/1 matrix with roughly 10% of the entries set.
    adj_mat = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse()
    true_classes, row_count = get_true_classes(adj_mat)
    edges, degrees = get_edges_and_degrees(adj_mat)
    # Only checks that the call completes without raising.
    _ = fixed_unigram_candidate_sampler_new(true_classes, row_count, degrees, 0.75)