from triacontagon.batch import Batcher
from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder
import torch


# Dense 0/1 matrices used as fixtures below. A 1 at (row, column) is an
# entry the tests expect to see reported as an edge (row, column).
_GENE_GENE_FIRST = [
    [0, 1, 0, 1, 0],
    [0, 0, 0, 0, 1],
    [1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
]

_GENE_GENE_SECOND = [
    [1, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 0, 0, 1],
    [0, 1, 0, 0, 0],
    [0, 0, 1, 0, 0],
]

_GENE_DRUG = [
    [0, 1, 0, 0],
    [1, 0, 0, 1],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 1, 1, 0],
]


def _as_sparse(rows):
    """Build a tensor from nested lists and convert it with to_sparse()."""
    return torch.tensor(rows).to_sparse()


def _collect_keys(batcher, make_key):
    """Iterate *batcher*, print every item, and return the set of keys.

    *make_key* is called once per yielded item and must return a
    hashable value identifying that item.
    """
    seen = set()
    for batch in batcher:
        print(batch)
        seen.add(make_key(batch))
    return seen


def test_batcher_01():
    """One vertex type, one relation: compare the first edge of every
    batch (batch_size=1) against the non-zero matrix coordinates."""
    data = Data()
    data.add_vertex_type('Gene', 5)
    data.add_edge_type(
        'Gene-Gene', 0, 0,
        [_as_sparse(_GENE_GENE_FIRST)],
        dedicom_decoder)

    batcher = Batcher(data, batch_size=1)
    observed = _collect_keys(
        batcher,
        lambda batch: tuple(batch.edges[0].tolist()))

    assert observed == {
        (0, 1), (0, 3), (1, 4),
        (2, 0), (3, 2), (4, 3),
    }


def test_batcher_02():
    """One vertex type, two relations: keys are prefixed with the
    relation_type_index reported by each batch."""
    data = Data()
    data.add_vertex_type('Gene', 5)
    data.add_edge_type(
        'Gene-Gene', 0, 0,
        [_as_sparse(_GENE_GENE_FIRST), _as_sparse(_GENE_GENE_SECOND)],
        dedicom_decoder)

    batcher = Batcher(data, batch_size=1)
    observed = _collect_keys(
        batcher,
        lambda batch: (batch.relation_type_index,
                       *batch.edges[0].tolist()))

    assert observed == {
        (0, 0, 1), (0, 0, 3), (0, 1, 4),
        (0, 2, 0), (0, 3, 2), (0, 4, 3),
        (1, 0, 0), (1, 0, 2), (1, 1, 3),
        (1, 2, 4), (1, 3, 1), (1, 4, 2),
    }


def test_batcher_03():
    """Two vertex types, two edge types: keys are prefixed with
    vertex_type_row, vertex_type_column and relation_type_index."""
    data = Data()
    data.add_vertex_type('Gene', 5)
    data.add_vertex_type('Drug', 4)
    data.add_edge_type(
        'Gene-Gene', 0, 0,
        [_as_sparse(_GENE_GENE_FIRST), _as_sparse(_GENE_GENE_SECOND)],
        dedicom_decoder)
    data.add_edge_type(
        'Gene-Drug', 0, 1,
        [_as_sparse(_GENE_DRUG)],
        dedicom_decoder)

    batcher = Batcher(data, batch_size=1)
    observed = _collect_keys(
        batcher,
        lambda batch: (batch.vertex_type_row,
                       batch.vertex_type_column,
                       batch.relation_type_index,
                       *batch.edges[0].tolist()))

    assert observed == {
        # Gene-Gene, relation 0
        (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), (0, 0, 0, 1, 4),
        (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
        # Gene-Gene, relation 1
        (0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3),
        (0, 0, 1, 2, 4), (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
        # Gene-Drug, relation 0
        (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
        (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
        (0, 1, 0, 4, 2),
    }