diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py index e72f7a8..088cee3 100644 --- a/tests/triacontagon/test_sampling.py +++ b/tests/triacontagon/test_sampling.py @@ -2,7 +2,8 @@ from triacontagon.data import Data from triacontagon.sampling import fixed_unigram_candidate_sampler, \ get_true_classes, \ negative_sample_adj_mat, \ - negative_sample_data + negative_sample_data, \ + get_edges_and_degrees from triacontagon.decode import dedicom_decoder import torch import time @@ -22,6 +23,29 @@ def test_fixed_unigram_candidate_sampler_01(): print('res:', res) +def test_fixed_unigram_candidate_sampler_02(): + foo_bar = torch.tensor([ + [0, 1, 0, 1], + [0, 0, 0, 1], + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 1] + ], dtype=torch.float32) + + # bar_foo = foo_bar.transpose(0, 1).to_sparse().coalesce() + bar_foo = foo_bar.to_sparse().coalesce() + + true_classes, row_count = get_true_classes(bar_foo) + print('true_classes:', true_classes) + print('row_count:', row_count) + + edges_pos, degrees = get_edges_and_degrees(bar_foo) + + res = fixed_unigram_candidate_sampler(true_classes, row_count, + degrees, 0.75) + print('res:', res) + + def test_get_true_classes_01(): adj_mat = torch.tensor([ [0, 1, 0, 1, 0],