@@ -9,6 +9,7 @@ def dfill(a): | |||||
torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1, | torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1, | ||||
torch.tensor([n]) | torch.tensor([n]) | ||||
]) | ]) | ||||
print('b:',b) | |||||
res = torch.arange(n)[b[:-1]] | res = torch.arange(n)[b[:-1]] | ||||
res = torch.repeat_interleave(res, b[1:] - b[:-1]) | res = torch.repeat_interleave(res, b[1:] - b[:-1]) | ||||
return res | return res | ||||
@@ -21,7 +21,7 @@ from itertools import product, \ | |||||
from functools import reduce | from functools import reduce | ||||
def fixed_unigram_candidate_sampler_new( | |||||
def fixed_unigram_candidate_sampler( | |||||
true_classes: torch.Tensor, | true_classes: torch.Tensor, | ||||
num_repeats: torch.Tensor, | num_repeats: torch.Tensor, | ||||
unigrams: torch.Tensor, | unigrams: torch.Tensor, | ||||
@@ -50,7 +50,7 @@ def fixed_unigram_candidate_sampler_new( | |||||
dtype=true_classes.dtype) | dtype=true_classes.dtype) | ||||
], dim=1) | ], dim=1) | ||||
indices = torch.repeat_interleave(torch.arange(len(unigrams)), num_repeats) | |||||
indices = torch.repeat_interleave(torch.arange(len(true_classes)), num_repeats) | |||||
indices = torch.cat([ torch.arange(len(indices)).view(-1, 1), | indices = torch.cat([ torch.arange(len(indices)).view(-1, 1), | ||||
indices.view(-1, 1) ], dim=1) | indices.view(-1, 1) ], dim=1) | ||||
@@ -135,7 +135,7 @@ def fixed_unigram_candidate_sampler_slow( | |||||
return torch.tensor(res) | return torch.tensor(res) | ||||
def fixed_unigram_candidate_sampler( | |||||
def fixed_unigram_candidate_sampler_old( | |||||
true_classes: torch.Tensor, | true_classes: torch.Tensor, | ||||
num_repeats: torch.Tensor, | num_repeats: torch.Tensor, | ||||
unigrams: torch.Tensor, | unigrams: torch.Tensor, | ||||
@@ -3,11 +3,12 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \ | |||||
get_true_classes, \ | get_true_classes, \ | ||||
negative_sample_adj_mat, \ | negative_sample_adj_mat, \ | ||||
negative_sample_data, \ | negative_sample_data, \ | ||||
get_edges_and_degrees, \ | |||||
fixed_unigram_candidate_sampler_new | |||||
get_edges_and_degrees | |||||
import triacontagon.sampling | |||||
from triacontagon.decode import dedicom_decoder | from triacontagon.decode import dedicom_decoder | ||||
import torch | import torch | ||||
import time | import time | ||||
import pytest | |||||
def test_fixed_unigram_candidate_sampler_01(): | def test_fixed_unigram_candidate_sampler_01(): | ||||
@@ -41,6 +42,7 @@ def test_fixed_unigram_candidate_sampler_02(): | |||||
print('row_count:', row_count) | print('row_count:', row_count) | ||||
edges_pos, degrees = get_edges_and_degrees(bar_foo) | edges_pos, degrees = get_edges_and_degrees(bar_foo) | ||||
print('degrees:', degrees) | |||||
res = fixed_unigram_candidate_sampler(true_classes, row_count, | res = fixed_unigram_candidate_sampler(true_classes, row_count, | ||||
degrees, 0.75) | degrees, 0.75) | ||||
@@ -117,7 +119,9 @@ def test_negative_sample_data_01(): | |||||
def test_fixed_unigram_candidate_sampler_new_01(): | def test_fixed_unigram_candidate_sampler_new_01(): | ||||
x = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse() | |||||
if 'fixed_unigram_candidate_sampler_new' not in dir(triacontagon.sampling): | |||||
pytest.skip('fixed_unigram_candidate_sampler_new not found') | |||||
x = (torch.rand((10, 10)) < .05).to(torch.float32).to_sparse() | |||||
true_classes, row_count = get_true_classes(x) | true_classes, row_count = get_true_classes(x) | ||||
edges, degrees = get_edges_and_degrees(x) | edges, degrees = get_edges_and_degrees(x) | ||||
# import pdb | # import pdb | ||||