From 604b81675eae47ce7e89261f510a66d56946467b Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 20 Aug 2020 12:45:51 +0200 Subject: [PATCH] Add test_split_adj_mat_(01|02|03)(). --- tests/triacontagon/test_sampling.py | 1 - tests/triacontagon/test_split.py | 41 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 tests/triacontagon/test_split.py diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py index 0b45769..e72f7a8 100644 --- a/tests/triacontagon/test_sampling.py +++ b/tests/triacontagon/test_sampling.py @@ -22,7 +22,6 @@ def test_fixed_unigram_candidate_sampler_01(): print('res:', res) - def test_get_true_classes_01(): adj_mat = torch.tensor([ [0, 1, 0, 1, 0], diff --git a/tests/triacontagon/test_split.py b/tests/triacontagon/test_split.py new file mode 100644 index 0000000..2cfc8e7 --- /dev/null +++ b/tests/triacontagon/test_split.py @@ -0,0 +1,41 @@ +from triacontagon.split import split_adj_mat +from triacontagon.util import _equal +import torch + + +def test_split_adj_mat_01(): + adj_mat = torch.tensor([ + [0, 1, 0, 0, 1], + [0, 0, 1, 0, 1], + [1, 0, 0, 1, 0], + [0, 0, 1, 1, 0] + ]).to_sparse() + + (res,) = split_adj_mat(adj_mat, (1.,)) + assert torch.all(_equal(res, adj_mat)) + + +def test_split_adj_mat_02(): + adj_mat = torch.tensor([ + [0, 1, 0, 0, 1], + [0, 0, 1, 0, 1], + [1, 0, 0, 1, 0], + [0, 0, 1, 1, 0] + ]).to_sparse() + + a, b = split_adj_mat(adj_mat, ( .5, .5 )) + assert torch.all(_equal(a+b, adj_mat)) + + +def test_split_adj_mat_03(): + adj_mat = torch.tensor([ + [0, 1, 0, 0, 1], + [0, 0, 1, 0, 1], + [1, 0, 0, 1, 0], + [0, 0, 1, 1, 0] + ]).to_sparse() + + a, b, c = split_adj_mat(adj_mat, ( .8, .1, .1 )) + print('a:', a.to_dense(), 'b:', b.to_dense(), 'c:', c.to_dense()) + + assert torch.all(_equal(a+b+c, adj_mat))