"""Tests for ``triacontagon.batch``: ``_same_data_org``, ``Batcher`` and
``DualBatcher``."""

from triacontagon.batch import _same_data_org, \
    DualBatcher, \
    Batcher
from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder
import torch


# Dense 0/1 adjacency matrices shared by several tests. They are stored as
# plain nested lists and converted into a fresh sparse tensor on every use,
# so no test can observe state left behind by another.
_GENE_GENE_0 = [
    [0, 1, 0, 1, 0],
    [0, 0, 0, 0, 1],
    [1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
]
_GENE_GENE_1 = [
    [0, 0, 1, 0, 1],
    [0, 0, 0, 1, 0],
    [0, 0, 0, 0, 1],
    [0, 1, 0, 0, 0],
    [0, 0, 1, 0, 0],
]
_GENE_DRUG_0 = [
    [0, 1, 0, 0],
    [1, 0, 0, 1],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 1, 1, 0],
]

# (row, column) coordinates of the nonzero entries of each matrix above.
_GENE_GENE_0_EDGES = {(0, 1), (0, 3), (1, 4), (2, 0), (3, 2), (4, 3)}
_GENE_GENE_1_EDGES = {(0, 2), (0, 4), (1, 3), (2, 4), (3, 1), (4, 2)}
_GENE_DRUG_0_EDGES = {(0, 1), (1, 0), (1, 3), (2, 1), (3, 2), (4, 1), (4, 2)}

# Every edge of ``_gene_drug_data()``, keyed as
# (vertex_type_row, vertex_type_column, relation_type_index, row, column).
_GENE_DRUG_DATA_EDGES = (
    {(0, 0, 0) + e for e in _GENE_GENE_0_EDGES}
    | {(0, 0, 1) + e for e in _GENE_GENE_1_EDGES}
    | {(0, 1, 0) + e for e in _GENE_DRUG_0_EDGES}
)


def _sparse(rows):
    """Return a new sparse tensor built from the dense nested list ``rows``."""
    return torch.tensor(rows).to_sparse()


def _gene_data(*gene_gene_mats):
    """Return a Data with a single 5-vertex 'Gene' type and a 'Gene-Gene'
    edge type holding one adjacency matrix per entry of ``gene_gene_mats``."""
    d = Data()
    d.add_vertex_type('Gene', 5)
    d.add_edge_type('Gene-Gene', 0, 0,
                    [_sparse(m) for m in gene_gene_mats], dedicom_decoder)
    return d


def _gene_drug_data():
    """Return a Data with 'Gene' (5) and 'Drug' (4) vertex types, a 'Gene-Gene'
    edge type with two matrices and a 'Gene-Drug' edge type with one."""
    d = Data()
    d.add_vertex_type('Gene', 5)
    d.add_vertex_type('Drug', 4)
    d.add_edge_type('Gene-Gene', 0, 0,
                    [_sparse(_GENE_GENE_0), _sparse(_GENE_GENE_1)],
                    dedicom_decoder)
    d.add_edge_type('Gene-Drug', 0, 1,
                    [_sparse(_GENE_DRUG_0)], dedicom_decoder)
    return d


def _batch_keys(batch):
    """Return the set of (vertex_type_row, vertex_type_column,
    relation_type_index, row, column) keys for every edge in ``batch``."""
    prefix = (batch.vertex_type_row, batch.vertex_type_column,
              batch.relation_type_index)
    return {prefix + tuple(e.tolist()) for e in batch.edges}


def _visit(batches, batch_size):
    """Exhaust ``batches`` and return the union of their edge keys.

    Also checks that no batch holds more than ``batch_size`` edges, so a
    batcher that over-fills batches cannot hide edges from the caller.
    """
    visited = set()
    for t in batches:
        print(t)  # captured by pytest; shown only when a test fails
        assert len(t.edges) <= batch_size
        visited |= _batch_keys(t)
    return visited


def test_same_data_org_01():
    data = Data()
    assert _same_data_org(data, data)
    data.add_vertex_type('Foo', 10)
    assert _same_data_org(data, data)
    data.add_vertex_type('Bar', 10)
    assert _same_data_org(data, data)

    data_1 = Data()
    assert not _same_data_org(data, data_1)
    data_1.add_vertex_type('Foo', 10)
    assert not _same_data_org(data, data_1)
    data_1.add_vertex_type('Bar', 10)
    assert _same_data_org(data, data_1)


def test_same_data_org_02():
    data = Data()
    data.add_vertex_type('Foo', 4)
    data.add_edge_type('Foo-Foo', 0, 0, [
        _sparse([
            [0, 0, 0, 1],
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [1, 0, 1, 0],
        ])
    ], dedicom_decoder)
    assert _same_data_org(data, data)

    data_1 = Data()
    data_1.add_vertex_type('Foo', 4)
    data_1.add_edge_type('Foo-Foo', 0, 0, [
        _sparse([
            [0, 0, 0, 1],
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [1, 0, 0, 0],  # differs from ``data`` only in this row
        ])
    ], dedicom_decoder)
    assert not _same_data_org(data, data_1)


def test_batcher_01():
    b = Batcher(_gene_data(_GENE_GENE_0), batch_size=1)
    visited = {k[3:] for k in _visit(b, 1)}
    assert visited == _GENE_GENE_0_EDGES


def test_batcher_02():
    b = Batcher(_gene_data(_GENE_GENE_0, _GENE_GENE_1), batch_size=1)
    visited = {k[2:] for k in _visit(b, 1)}
    assert visited == (
        {(0,) + e for e in _GENE_GENE_0_EDGES}
        | {(1,) + e for e in _GENE_GENE_1_EDGES}
    )


def test_batcher_03():
    b = Batcher(_gene_drug_data(), batch_size=1)
    assert _visit(b, 1) == _GENE_DRUG_DATA_EDGES


def test_batcher_04():
    b = Batcher(_gene_data(_GENE_GENE_0), batch_size=3)
    visited = {k[3:] for k in _visit(b, 3)}
    assert visited == _GENE_GENE_0_EDGES


def test_batcher_05():
    b = Batcher(_gene_drug_data(), batch_size=5)
    assert _visit(b, 5) == _GENE_DRUG_DATA_EDGES


def test_dual_batcher_01():
    d = _gene_drug_data()
    b = DualBatcher(d, d, batch_size=5)
    visited_pos = set()
    visited_neg = set()
    for t_pos, t_neg in b:
        # Both halves of each pair must describe the same relation and
        # carry the same number of edges.
        assert t_pos.vertex_type_row == t_neg.vertex_type_row
        assert t_pos.vertex_type_column == t_neg.vertex_type_column
        assert t_pos.relation_type_index == t_neg.relation_type_index
        assert len(t_pos.edges) == len(t_neg.edges)
        visited_pos |= _batch_keys(t_pos)
        visited_neg |= _batch_keys(t_neg)
    assert visited_pos == _GENE_DRUG_DATA_EDGES
    assert visited_neg == _GENE_DRUG_DATA_EDGES