|
|
@@ -3,11 +3,12 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \ |
|
|
|
get_true_classes, \
|
|
|
|
negative_sample_adj_mat, \
|
|
|
|
negative_sample_data, \
|
|
|
|
get_edges_and_degrees, \
|
|
|
|
fixed_unigram_candidate_sampler_new
|
|
|
|
get_edges_and_degrees
|
|
|
|
import triacontagon.sampling
|
|
|
|
from triacontagon.decode import dedicom_decoder
|
|
|
|
import torch
|
|
|
|
import time
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
|
|
def test_fixed_unigram_candidate_sampler_01():
|
|
|
@@ -41,6 +42,7 @@ def test_fixed_unigram_candidate_sampler_02(): |
|
|
|
print('row_count:', row_count)
|
|
|
|
|
|
|
|
edges_pos, degrees = get_edges_and_degrees(bar_foo)
|
|
|
|
print('degrees:', degrees)
|
|
|
|
|
|
|
|
res = fixed_unigram_candidate_sampler(true_classes, row_count,
|
|
|
|
degrees, 0.75)
|
|
|
@@ -117,7 +119,9 @@ def test_negative_sample_data_01(): |
|
|
|
|
|
|
|
|
|
|
|
def test_fixed_unigram_candidate_sampler_new_01():
|
|
|
|
x = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse()
|
|
|
|
if 'fixed_unigram_candidate_sampler_new' not in dir(triacontagon.sampling):
|
|
|
|
pytest.skip('fixed_unigram_candidate_sampler_new not found')
|
|
|
|
x = (torch.rand((10, 10)) < .05).to(torch.float32).to_sparse()
|
|
|
|
true_classes, row_count = get_true_classes(x)
|
|
|
|
edges, degrees = get_edges_and_degrees(x)
|
|
|
|
# import pdb
|
|
|
|