From 670237c3f8bc6796d65cc83155ebfe474393b157 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 21 Aug 2020 17:15:58 +0200 Subject: [PATCH] Work on fixed_unigram_candidate_sampler_new(). --- src/triacontagon/cumcount.py | 1 + src/triacontagon/sampling.py | 6 +++--- tests/triacontagon/test_sampling.py | 10 +++++++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/triacontagon/cumcount.py b/src/triacontagon/cumcount.py index faee2be..5c88da6 100644 --- a/src/triacontagon/cumcount.py +++ b/src/triacontagon/cumcount.py @@ -9,6 +9,7 @@ def dfill(a): torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1, torch.tensor([n]) ]) + print('b:',b) res = torch.arange(n)[b[:-1]] res = torch.repeat_interleave(res, b[1:] - b[:-1]) return res diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index 76e42dc..d43afd0 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -21,7 +21,7 @@ from itertools import product, \ from functools import reduce -def fixed_unigram_candidate_sampler_new( +def fixed_unigram_candidate_sampler( true_classes: torch.Tensor, num_repeats: torch.Tensor, unigrams: torch.Tensor, @@ -50,7 +50,7 @@ def fixed_unigram_candidate_sampler_new( dtype=true_classes.dtype) ], dim=1) - indices = torch.repeat_interleave(torch.arange(len(unigrams)), num_repeats) + indices = torch.repeat_interleave(torch.arange(len(true_classes)), num_repeats) indices = torch.cat([ torch.arange(len(indices)).view(-1, 1), indices.view(-1, 1) ], dim=1) @@ -135,7 +135,7 @@ def fixed_unigram_candidate_sampler_slow( return torch.tensor(res) -def fixed_unigram_candidate_sampler( +def fixed_unigram_candidate_sampler_old( true_classes: torch.Tensor, num_repeats: torch.Tensor, unigrams: torch.Tensor, diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py index 618f172..d5480fb 100644 --- a/tests/triacontagon/test_sampling.py +++ b/tests/triacontagon/test_sampling.py @@ -3,11 +3,12 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \ get_true_classes, \ negative_sample_adj_mat, \ negative_sample_data, \ - get_edges_and_degrees, \ - fixed_unigram_candidate_sampler_new + get_edges_and_degrees +import triacontagon.sampling from triacontagon.decode import dedicom_decoder import torch import time +import pytest def test_fixed_unigram_candidate_sampler_01(): @@ -41,6 +42,7 @@ def test_fixed_unigram_candidate_sampler_02(): print('row_count:', row_count) edges_pos, degrees = get_edges_and_degrees(bar_foo) + print('degrees:', degrees) res = fixed_unigram_candidate_sampler(true_classes, row_count, degrees, 0.75) @@ -117,7 +119,9 @@ def test_negative_sample_data_01(): def test_fixed_unigram_candidate_sampler_new_01(): - x = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse() + if 'fixed_unigram_candidate_sampler_new' not in dir(triacontagon.sampling): + pytest.skip('fixed_unigram_candidate_sampler_new not found') + x = (torch.rand((10, 10)) < .05).to(torch.float32).to_sparse() true_classes, row_count = get_true_classes(x) edges, degrees = get_edges_and_degrees(x) # import pdb