From 8fbad74dfa2fec703345904c76a4a08bdc35a493 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 20 Aug 2020 18:56:52 +0200 Subject: [PATCH] Debug negative sampling. --- tests/triacontagon/test_sampling.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py index e72f7a8..088cee3 100644 --- a/tests/triacontagon/test_sampling.py +++ b/tests/triacontagon/test_sampling.py @@ -2,7 +2,8 @@ from triacontagon.data import Data from triacontagon.sampling import fixed_unigram_candidate_sampler, \ get_true_classes, \ negative_sample_adj_mat, \ - negative_sample_data + negative_sample_data, \ + get_edges_and_degrees from triacontagon.decode import dedicom_decoder import torch import time @@ -22,6 +23,29 @@ def test_fixed_unigram_candidate_sampler_01(): print('res:', res) +def test_fixed_unigram_candidate_sampler_02(): + foo_bar = torch.tensor([ + [0, 1, 0, 1], + [0, 0, 0, 1], + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 1] + ], dtype=torch.float32) + + # bar_foo = foo_bar.transpose(0, 1).to_sparse().coalesce() + bar_foo = foo_bar.to_sparse().coalesce() + + true_classes, row_count = get_true_classes(bar_foo) + print('true_classes:', true_classes) + print('row_count:', row_count) + + edges_pos, degrees = get_edges_and_degrees(bar_foo) + + res = fixed_unigram_candidate_sampler(true_classes, row_count, + degrees, 0.75) + print('res:', res) + + def test_get_true_classes_01(): adj_mat = torch.tensor([ [0, 1, 0, 1, 0],