diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index 7f79719..58b7ba0 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -17,41 +17,41 @@ import time def fixed_unigram_candidate_sampler( - true_classes: Union[np.array, torch.Tensor], - unigrams: List[Union[int, float]], + true_classes: torch.Tensor, + num_repeats: torch.Tensor, + unigrams: torch.Tensor, distortion: float = 1.): - if isinstance(true_classes, torch.Tensor): - true_classes = true_classes.detach().cpu().numpy() - - if isinstance(unigrams, torch.Tensor): - unigrams = unigrams.detach().cpu().numpy() - if len(true_classes.shape) != 2: raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)') - num_samples = true_classes.shape[0] - unigrams = np.array(unigrams) + if len(num_repeats.shape) != 1: + raise ValueError('num_repeats must be 1D') + + num_rows = true_classes.shape[0] + # unigrams = np.array(unigrams) if distortion != 1.: - unigrams = unigrams.astype(np.float64) ** distortion + unigrams = unigrams.to(torch.float64) ** distortion # print('unigrams:', unigrams) - indices = np.arange(num_samples) - result = np.zeros(num_samples, dtype=np.int64) + indices = torch.arange(num_rows) + indices = torch.repeat_interleave(indices, num_repeats) + num_samples = len(indices) + result = torch.zeros(num_samples, dtype=torch.long) while len(indices) > 0: # print('len(indices):', len(indices)) sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices)) - candidates = np.array(list(sampler)) - candidates = np.reshape(candidates, (len(indices), 1)) + candidates = torch.tensor(list(sampler)) + candidates = candidates.view(len(indices), 1) # print('candidates:', candidates) # print('true_classes:', true_classes[indices, :]) - result[indices] = candidates.T + result[indices] = candidates.transpose(0, 1) # print('result:', result) mask = (candidates == true_classes[indices, :]) - mask = mask.sum(1).astype(np.bool) + mask = mask.sum(1).to(torch.bool) # print('mask:', mask) indices = indices[mask] # result[indices] = 0 - return torch.tensor(result) + return result def get_edges_and_degrees(adj_mat: torch.Tensor) -> \ @@ -71,7 +71,7 @@ def get_edges_and_degrees(adj_mat: torch.Tensor) -> \ return edges_pos, degrees -def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor: +def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: indices = adj_mat.indices() row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long) #print('indices[0]:', indices[0], count[indices[0]]) @@ -105,11 +105,11 @@ def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor: true_classes[row, count[row]] = col count[row] += 1 ''' - t = time.time() - true_classes = torch.repeat_interleave(true_classes, row_count, dim=0) - print('repeat_interleave() took:', time.time() - t) + # t = time.time() + # true_classes = torch.repeat_interleave(true_classes, row_count, dim=0) + # print('repeat_interleave() took:', time.time() - t) - return true_classes + return true_classes, row_count def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: @@ -118,12 +118,12 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: edges_pos, degrees = get_edges_and_degrees(adj_mat) - true_classes = get_true_classes(adj_mat) + true_classes, row_count = get_true_classes(adj_mat) # true_classes = edges_pos[:, 1].view(-1, 1) # print('true_classes:', true_classes) neg_neighbors = fixed_unigram_candidate_sampler( - true_classes, degrees, 0.75).to(adj_mat.device) + true_classes, row_count, degrees, 0.75).to(adj_mat.device) print('neg_neighbors:', neg_neighbors) diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py index 454d3c1..6bba237 100644 --- a/tests/triacontagon/test_sampling.py +++ b/tests/triacontagon/test_sampling.py @@ -16,9 +16,11 @@ def test_get_true_classes_01(): [0, 1, 0, 0, 0] ], dtype=torch.float).to_sparse() - true_classes = get_true_classes(adj_mat) + true_classes, row_count = get_true_classes(adj_mat) print('true_classes:', true_classes) + true_classes = torch.repeat_interleave(true_classes, row_count, dim=0) + assert torch.all(true_classes == torch.tensor([ [1, 3], [1, 3], @@ -32,10 +34,10 @@ def test_get_true_classes_01(): def test_get_true_classes_02(): - adj_mat = (torch.rand(2000, 2000) < 0.1).to_sparse() + adj_mat = torch.rand(2000, 2000).round().to_sparse() t = time.time() - true_classes = get_true_classes(adj_mat) + true_classes, row_count = get_true_classes(adj_mat) print('Elapsed:', time.time() - t) print('true_classes.shape:', true_classes.shape)