diff --git a/src/triacontagon/split.py b/src/triacontagon/split.py index 68826f1..77f5eaf 100644 --- a/src/triacontagon/split.py +++ b/src/triacontagon/split.py @@ -7,6 +7,10 @@ import torch def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]): + ratios = list(ratios) + if sum(ratios) != 1: + raise ValueError('Sum of ratios must be 1') + indices = adj_mat.indices() values = adj_mat.values() @@ -33,6 +37,10 @@ def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]): def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]): + ratios = list(ratios) + if sum(ratios) != 1: + raise ValueError('Sum of ratios must be 1') + res = [ split_adj_mat(adj_mat, ratios) \ for adj_mat in et.adjacency_matrices ] diff --git a/tests/triacontagon/test_split.py b/tests/triacontagon/test_split.py index 2cfc8e7..85ddc9d 100644 --- a/tests/triacontagon/test_split.py +++ b/tests/triacontagon/test_split.py @@ -1,5 +1,7 @@ -from triacontagon.split import split_adj_mat +from triacontagon.split import split_adj_mat, \ + split_edge_type from triacontagon.util import _equal +from triacontagon.data import EdgeType import torch @@ -39,3 +41,84 @@ def test_split_adj_mat_03(): print('a:', a.to_dense(), 'b:', b.to_dense(), 'c:', c.to_dense()) assert torch.all(_equal(a+b+c, adj_mat)) + + +def test_split_edge_type_01(): + et = EdgeType('Dummy', 0, 1, [ + torch.tensor([ + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 1], + [1, 0, 0, 0, 1], + [0, 1, 0, 1, 0] + ]).to_sparse() + ], None, None) + + res = split_edge_type(et, (1.,)) + + assert torch.all(_equal(et.adjacency_matrices[0], + res[0].adjacency_matrices[0])) + + +def test_split_edge_type_02(): + et = EdgeType('Dummy', 0, 1, [ + torch.tensor([ + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 1], + [1, 0, 0, 0, 1], + [0, 1, 0, 1, 0] + ]).to_sparse() + ], None, None) + + res = split_edge_type(et, (.5, .5)) + + assert torch.all(_equal(et.adjacency_matrices[0], + res[0].adjacency_matrices[0] + \ + res[1].adjacency_matrices[0])) + + +def test_split_edge_type_03(): + et = EdgeType('Dummy', 0, 1, [ + torch.tensor([ + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 1], + [1, 0, 0, 0, 1], + [0, 1, 0, 1, 0] + ]).to_sparse() + ], None, None) + + res = split_edge_type(et, (.4, .4, .2)) + + assert torch.all(_equal(et.adjacency_matrices[0], + res[0].adjacency_matrices[0] + \ + res[1].adjacency_matrices[0] + \ + res[2].adjacency_matrices[0])) + + +def test_split_edge_type_04(): + et = EdgeType('Dummy', 0, 1, [ + torch.tensor([ + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 1], + [1, 0, 0, 0, 1], + [0, 1, 0, 1, 0] + ]).to_sparse(), + + torch.tensor([ + [1, 0, 0, 0, 0], + [0, 1, 0, 1, 0], + [0, 0, 1, 1, 0], + [1, 0, 1, 0, 0] + ]).to_sparse() + ], None, None) + + res = split_edge_type(et, (.4, .4, .2)) + + assert torch.all(_equal(et.adjacency_matrices[0], + res[0].adjacency_matrices[0] + \ + res[1].adjacency_matrices[0] + \ + res[2].adjacency_matrices[0])) + + assert torch.all(_equal(et.adjacency_matrices[1], + res[0].adjacency_matrices[1] + \ + res[1].adjacency_matrices[1] + \ + res[2].adjacency_matrices[1]))