| @@ -1,22 +1,30 @@ | |||
| import torch | |||
| import numpy as np | |||
def dfill(a):
    """Map every element of a 1-D tensor to the start index of its run.

    A run is a maximal stretch of consecutive equal values, so
    ``[1, 1, 2, 2, 2, 5]`` becomes ``[0, 0, 2, 2, 2, 5]``.

    Args:
        a: 1-D tensor. Only adjacent elements are compared, so equal
            values that are not next to each other form separate runs.

    Returns:
        An int64 tensor of the same length as ``a``, on ``a``'s device.
    """
    n = torch.numel(a)
    if n == 0:
        # Without this guard the boundary tensor would be [0, 0] and the
        # run-start lookup would index into an empty range.
        return torch.empty(0, dtype=torch.int64, device=a.device)

    # Positions where the value differs from its predecessor start a new run.
    run_starts = torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1

    # Run boundaries: 0, every change point, and n as the closing sentinel.
    # The sentinels are created on a's device so torch.cat never mixes devices.
    b = torch.cat([
        torch.zeros(1, dtype=torch.int64, device=a.device),
        run_starts,
        torch.full((1,), n, dtype=torch.int64, device=a.device),
    ])

    # Repeat each run's start index once per element of that run.
    return torch.repeat_interleave(b[:-1], b[1:] - b[:-1])
def argunsort(s):
    """Return the inverse of the permutation ``s``.

    If ``s`` is the result of an argsort, indexing the sorted data with the
    returned tensor restores the original order.

    Args:
        s: 1-D integer index tensor holding a permutation of ``0..n-1``.

    Returns:
        An int64 tensor ``u`` with ``u[s[k]] == k``, on ``s``'s device.
    """
    n = torch.numel(s)
    # Allocate on s's device so the scatter below never mixes devices.
    u = torch.empty(n, dtype=torch.int64, device=s.device)
    u[s] = torch.arange(n, dtype=torch.int64, device=s.device)
    return u
def cumcount(a):
    """Number each element by how many equal values precede it.

    For ``[1, 1, 2, 1]`` the result is ``[0, 1, 0, 2]``.

    Args:
        a: 1-D tensor of sortable values.

    Returns:
        An integer tensor of the same length as ``a``, on ``a``'s device.
    """
    n = torch.numel(a)
    if n == 0:
        return torch.empty(0, dtype=torch.int64, device=a.device)

    # The sort MUST be stable: equal values have to keep their order of
    # occurrence, otherwise the counts inside a group come out permuted.
    order = np.argsort(a.detach().cpu().numpy(), kind='stable')
    s = torch.as_tensor(order, dtype=torch.int64, device=a.device)

    i = argunsort(s)
    b = a[s]

    # Position in the sorted array minus the start of its run gives the
    # running count inside the group; [i] restores the original order.
    return (torch.arange(n, device=a.device) - dfill(b))[i]
| @@ -21,6 +21,71 @@ from itertools import product, \ | |||
| from functools import reduce | |||
def fixed_unigram_candidate_sampler_new(
        true_classes: torch.Tensor,
        num_repeats: torch.Tensor,
        unigrams: torch.Tensor,
        distortion: float = 1.) -> torch.Tensor:
    """Draw negative classes per sample, avoiding true classes and repeats.

    For every row ``r`` of ``true_classes``, ``num_repeats[r]`` classes are
    drawn with probability proportional to the unigram weights. A draw is
    rejected and redrawn when it matches one of the row's true classes or a
    class already accepted for that row.

    Args:
        true_classes: 2-D tensor of shape (num_samples, num_true); unused
            slots hold -1 (presumably right-padded - verify against caller).
        num_repeats: 1-D tensor with the number of classes to draw per row.
        unigrams: 1-D tensor of non-negative sampling weights, one per class.
        distortion: exponent applied to the positive weights before
            sampling; 1.0 leaves the distribution unchanged.

    Returns:
        A 1-D int64 tensor of length ``num_repeats.sum()`` with the drawn
        classes, grouped by row in row order.

    Raises:
        ValueError: on wrong input rank, or when a row has fewer eligible
            classes than draws requested.
    """
    assert isinstance(true_classes, torch.Tensor)
    assert isinstance(num_repeats, torch.Tensor)
    assert isinstance(unigrams, torch.Tensor)
    distortion = float(distortion)

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    device = true_classes.device
    num_repeats = num_repeats.to(device)

    weights = unigrams.detach().to(device=device, dtype=torch.double)
    if distortion != 1.:
        # Zero-weight classes must stay unsamplable whatever the exponent.
        weights = torch.where(weights > 0, weights ** distortion,
            torch.zeros_like(weights))

    if torch.any((weights > 0).sum() -
            (true_classes >= 0).sum(dim=1) <
            num_repeats):
        raise ValueError('Not enough classes to choose from')

    true_class_count = true_classes.shape[1] - (true_classes == -1).sum(dim=1)

    # Work on a widened copy: the extra -1 columns receive the accepted draws
    # so later rounds reject them exactly like the original true classes.
    max_repeats = int(num_repeats.max()) if num_repeats.numel() > 0 else 0
    true_classes = torch.cat([
        true_classes,
        torch.full((len(true_classes), max_repeats), -1,
            dtype=true_classes.dtype, device=device)
    ], dim=1)

    # One entry per requested draw: `rows` is the sample it belongs to and
    # `slots` its position in the output. Row ids range over the samples,
    # not over the classes.
    rows = torch.repeat_interleave(
        torch.arange(len(true_classes), device=device), num_repeats)
    slots = torch.arange(len(rows), device=device)

    result = torch.zeros(len(rows), dtype=torch.long, device=device)

    while len(rows) > 0:
        sampler = torch.utils.data.WeightedRandomSampler(weights, len(rows))
        candidates = torch.tensor(list(sampler), dtype=torch.long, device=device)

        # Order the pending draws by (row, candidate): sort by candidate
        # first, then stably by row.
        inner_order = torch.argsort(candidates)
        outer_order = np.argsort(
            rows[inner_order].detach().cpu().numpy(), kind='stable')
        outer_order = torch.as_tensor(outer_order, dtype=torch.long, device=device)
        order = inner_order[outer_order]

        candidates = candidates[order]
        rows = rows[order]
        slots = slots[order]

        # Reject draws that hit a true (or previously accepted) class.
        rejected = (true_classes[rows] == candidates.view(-1, 1)).any(dim=1)

        # Reject possible duplicates inside this round. NOTE(review): the
        # rule is conservative - it flags any non-first draw of a row whose
        # class was already seen anywhere in the batch, so every true
        # in-row duplicate is caught, plus some harmless extra redraws.
        can_cum = cumcount(candidates)
        ind_cum = cumcount(rows)
        rejected = rejected | ((can_cum > 0) & (ind_cum > 0))

        accepted = ~rejected
        # Skip the bookkeeping when nothing was accepted, so cumcount() is
        # never called on an empty tensor.
        if bool(accepted.any()):
            acc_rows = rows[accepted]
            ofs = true_class_count[acc_rows] + cumcount(acc_rows)
            true_classes[acc_rows, ofs] = \
                candidates[accepted].to(true_classes.dtype)
            # acc_rows may contain the same row several times; accumulate
            # instead of relying on duplicate-index assignment order.
            true_class_count.index_add_(0, acc_rows,
                torch.ones_like(acc_rows, dtype=true_class_count.dtype))

        # Rejected slots are written too, but they are overwritten in a
        # later round.
        result[slots] = candidates

        rows = rows[rejected]
        slots = slots[rejected]

    return result
| def fixed_unigram_candidate_sampler_slow( | |||
| true_classes: torch.Tensor, | |||
| num_repeats: torch.Tensor, | |||
| @@ -162,9 +227,9 @@ def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] | |||
| # indices = indices.copy() | |||
| # true_classes[indices[0], 0] = indices[1] | |||
| t = time.time() | |||
| cc = cumcount(indices[0].cpu().numpy()) | |||
| cc = cumcount(indices[0]) | |||
| print('cumcount() took:', time.time() - t) | |||
| cc = torch.tensor(cc) | |||
| # cc = torch.tensor(cc) | |||
| t = time.time() | |||
| true_classes[indices[0], cc] = indices[1] | |||
| print('assignment took:', time.time() - t) | |||
| @@ -1,26 +1,27 @@ | |||
| from triacontagon.cumcount import dfill, \ | |||
| argunsort, \ | |||
| cumcount | |||
| import torch | |||
| import numpy as np | |||
def test_dfill_01():
    """dfill() maps each element to the start index of its run."""
    values = torch.tensor([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])
    output = dfill(values)
    expected = torch.tensor([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
    # torch.equal also checks the shape, which a broadcasting == would not.
    assert torch.equal(output, expected)
def test_argunsort_01():
    """argunsort() inverts the permutation produced by a stable argsort."""
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
    # The expected inverse below is only well defined for a stable sort.
    order = np.argsort(values.numpy(), kind='stable')
    output = argunsort(torch.tensor(order))
    expected = torch.tensor([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
    assert torch.equal(output, expected)
def test_cumcount_01():
    """cumcount() numbers each element by its earlier equal values."""
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
    output = cumcount(values)
    expected = torch.tensor([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
    # torch.equal also checks the shape, which a broadcasting == would not.
    assert torch.equal(output, expected)
| @@ -3,7 +3,8 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \ | |||
| get_true_classes, \ | |||
| negative_sample_adj_mat, \ | |||
| negative_sample_data, \ | |||
| get_edges_and_degrees | |||
| get_edges_and_degrees, \ | |||
| fixed_unigram_candidate_sampler_new | |||
| from triacontagon.decode import dedicom_decoder | |||
| import torch | |||
| import time | |||
| @@ -113,3 +114,12 @@ def test_negative_sample_data_01(): | |||
| ], dedicom_decoder) | |||
| d_neg = negative_sample_data(d) | |||
def test_fixed_unigram_candidate_sampler_new_01():
    """Smoke test: the new sampler runs on a small random sparse matrix."""
    # NOTE(review): the input is random and unseeded, so the sampler's
    # 'Not enough classes to choose from' check could presumably fire on an
    # unlucky draw - confirm, and consider a fixed input.
    x = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse()
    true_classes, row_count = get_true_classes(x)
    edges, degrees = get_edges_and_degrees(x)
    _ = fixed_unigram_candidate_sampler_new(true_classes, row_count, degrees, 0.75)