|
|
@@ -2,7 +2,8 @@ from triacontagon.data import Data |
|
|
|
from triacontagon.sampling import fixed_unigram_candidate_sampler, \
|
|
|
|
get_true_classes, \
|
|
|
|
negative_sample_adj_mat, \
|
|
|
|
negative_sample_data
|
|
|
|
negative_sample_data, \
|
|
|
|
get_edges_and_degrees
|
|
|
|
from triacontagon.decode import dedicom_decoder
|
|
|
|
import torch
|
|
|
|
import time
|
|
|
@@ -22,6 +23,29 @@ def test_fixed_unigram_candidate_sampler_01(): |
|
|
|
print('res:', res)
|
|
|
|
|
|
|
|
|
|
|
|
def test_fixed_unigram_candidate_sampler_02():
|
|
|
|
foo_bar = torch.tensor([
|
|
|
|
[0, 1, 0, 1],
|
|
|
|
[0, 0, 0, 1],
|
|
|
|
[0, 1, 0, 0],
|
|
|
|
[1, 0, 0, 0],
|
|
|
|
[0, 0, 1, 1]
|
|
|
|
], dtype=torch.float32)
|
|
|
|
|
|
|
|
# bar_foo = foo_bar.transpose(0, 1).to_sparse().coalesce()
|
|
|
|
bar_foo = foo_bar.to_sparse().coalesce()
|
|
|
|
|
|
|
|
true_classes, row_count = get_true_classes(bar_foo)
|
|
|
|
print('true_classes:', true_classes)
|
|
|
|
print('row_count:', row_count)
|
|
|
|
|
|
|
|
edges_pos, degrees = get_edges_and_degrees(bar_foo)
|
|
|
|
|
|
|
|
res = fixed_unigram_candidate_sampler(true_classes, row_count,
|
|
|
|
degrees, 0.75)
|
|
|
|
print('res:', res)
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_true_classes_01():
|
|
|
|
adj_mat = torch.tensor([
|
|
|
|
[0, 1, 0, 1, 0],
|
|
|
|