diff --git a/tests/triacontagon/test_loop.py b/tests/triacontagon/test_loop.py index 7456ab5..dcc975a 100644 --- a/tests/triacontagon/test_loop.py +++ b/tests/triacontagon/test_loop.py @@ -83,7 +83,7 @@ def test_train_loop_01(): [1, 0, 0, 1, 0], [0, 0, 1, 0, 1], [0, 1, 0, 0, 0] - ]) + ], dtype=torch.float32) foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2 foo_bar = torch.tensor([ @@ -92,7 +92,7 @@ def test_train_loop_01(): [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1] - ]) + ], dtype=torch.float32) bar_foo = foo_bar.transpose(0, 1) bar_bar = torch.tensor([ @@ -100,7 +100,7 @@ def test_train_loop_01(): [1, 0, 0, 0], [0, 1, 0, 1], [0, 1, 0, 0], - ]) + ], dtype=torch.float32) bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2 data.add_edge_type('Foo-Foo', 0, 0, [ diff --git a/tests/triacontagon/test_split.py b/tests/triacontagon/test_split.py index 85ddc9d..dc5459b 100644 --- a/tests/triacontagon/test_split.py +++ b/tests/triacontagon/test_split.py @@ -1,7 +1,10 @@ from triacontagon.split import split_adj_mat, \ - split_edge_type + split_edge_type, \ + split_data from triacontagon.util import _equal -from triacontagon.data import EdgeType +from triacontagon.data import EdgeType, \ + Data +from triacontagon.decode import dedicom_decoder import torch @@ -122,3 +125,125 @@ def test_split_edge_type_04(): res[0].adjacency_matrices[1] + \ res[1].adjacency_matrices[1] + \ res[2].adjacency_matrices[1])) + + +def test_split_data_01(): + data = Data() + data.add_vertex_type('Foo', 5) + data.add_vertex_type('Bar', 4) + + foo_foo = torch.tensor([ + [0, 1, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 1, 0, 0, 1], + [0, 1, 0, 0, 0], + [1, 0, 0, 1, 0] + ], dtype=torch.float32) + foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2 + + foo_bar = torch.tensor([ + [0, 1, 0, 1], + [0, 0, 0, 1], + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 1] + ], dtype=torch.float32) + bar_foo = foo_bar.transpose(0, 1) + + bar_bar = torch.tensor([ + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 1], + [0, 1, 0, 0], + ], dtype=torch.float32) + bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2 + + data.add_edge_type('Foo-Foo', 0, 0, [ + foo_foo.to_sparse().coalesce() + ], dedicom_decoder) + data.add_edge_type('Foo-Bar', 0, 1, [ + foo_bar.to_sparse().coalesce() + ], dedicom_decoder) + data.add_edge_type('Bar-Foo', 1, 0, [ + bar_foo.to_sparse().coalesce() + ], dedicom_decoder) + data.add_edge_type('Bar-Bar', 1, 1, [ + bar_bar.to_sparse().coalesce() + ], dedicom_decoder) + + (res,) = split_data(data, (1.,)) + + assert torch.all(_equal(res.edge_types[0, 0].adjacency_matrices[0], + data.edge_types[0, 0].adjacency_matrices[0])) + + assert torch.all(_equal(res.edge_types[0, 1].adjacency_matrices[0], + data.edge_types[0, 1].adjacency_matrices[0])) + + assert torch.all(_equal(res.edge_types[1, 0].adjacency_matrices[0], + data.edge_types[1, 0].adjacency_matrices[0])) + + assert torch.all(_equal(res.edge_types[1, 1].adjacency_matrices[0], + data.edge_types[1, 1].adjacency_matrices[0])) + + +def test_split_data_02(): + data = Data() + data.add_vertex_type('Foo', 5) + data.add_vertex_type('Bar', 4) + + foo_foo = torch.tensor([ + [0, 1, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 1, 0, 0, 1], + [0, 1, 0, 0, 0], + [1, 0, 0, 1, 0] + ], dtype=torch.float32) + foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2 + + foo_bar = torch.tensor([ + [0, 1, 0, 1], + [0, 0, 0, 1], + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 1] + ], dtype=torch.float32) + bar_foo = foo_bar.transpose(0, 1) + + bar_bar = torch.tensor([ + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 1], + [0, 1, 0, 0], + ], dtype=torch.float32) + bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2 + + data.add_edge_type('Foo-Foo', 0, 0, [ + foo_foo.to_sparse().coalesce() + ], dedicom_decoder) + data.add_edge_type('Foo-Bar', 0, 1, [ + foo_bar.to_sparse().coalesce() + ], dedicom_decoder) + data.add_edge_type('Bar-Foo', 1, 0, [ + bar_foo.to_sparse().coalesce() + ], dedicom_decoder) + data.add_edge_type('Bar-Bar', 1, 1, [ + bar_bar.to_sparse().coalesce() + ], dedicom_decoder) + + a, b = split_data(data, (.5,.5)) + + assert torch.all(_equal(a.edge_types[0, 0].adjacency_matrices[0] + \ + b.edge_types[0, 0].adjacency_matrices[0], + data.edge_types[0, 0].adjacency_matrices[0])) + + assert torch.all(_equal(a.edge_types[0, 1].adjacency_matrices[0] + \ + b.edge_types[0, 1].adjacency_matrices[0], + data.edge_types[0, 1].adjacency_matrices[0])) + + assert torch.all(_equal(a.edge_types[1, 0].adjacency_matrices[0] + \ + b.edge_types[1, 0].adjacency_matrices[0], + data.edge_types[1, 0].adjacency_matrices[0])) + + assert torch.all(_equal(a.edge_types[1, 1].adjacency_matrices[0] + \ + b.edge_types[1, 1].adjacency_matrices[0], + data.edge_types[1, 1].adjacency_matrices[0]))