"""Tests for ``triacontagon.split.split_adj_mat``."""

import torch

from triacontagon.split import split_adj_mat
from triacontagon.util import _equal


def _make_adj_mat():
    """Return the sparse 4x5 0/1 matrix shared by the split tests.

    A fresh tensor is built on every call so no test can observe
    another test's state through a shared object.
    """
    return torch.tensor([
        [0, 1, 0, 0, 1],
        [0, 0, 1, 0, 1],
        [1, 0, 0, 1, 0],
        [0, 0, 1, 1, 0],
    ]).to_sparse()


def test_split_adj_mat_01():
    """A single part with fraction 1.0 must equal the whole matrix."""
    adj_mat = _make_adj_mat()
    (res,) = split_adj_mat(adj_mat, (1.,))
    assert torch.all(_equal(res, adj_mat))


def test_split_adj_mat_02():
    """Two equal halves must sum back to the original matrix."""
    adj_mat = _make_adj_mat()
    a, b = split_adj_mat(adj_mat, (.5, .5))
    assert torch.all(_equal(a + b, adj_mat))


def test_split_adj_mat_03():
    """An uneven three-way split must still sum back to the original matrix."""
    adj_mat = _make_adj_mat()
    a, b, c = split_adj_mat(adj_mat, (.8, .1, .1))
    assert torch.all(_equal(a + b + c, adj_mat))