from triacontagon.model import (_per_layer_required_rows,
                                TrainingBatch)
from triacontagon.decode import dedicom_decoder
from triacontagon.data import Data
import torch


def test_per_layer_required_rows_01():
    """Run ``_per_layer_required_rows`` on a tiny two-vertex-type graph.

    The graph registers 4 ``Gene`` and 5 ``Drug`` vertices. It has three
    edge types (Gene-Gene, Gene-Drug, Drug-Drug). Each edge type receives a
    single sparse 0/1 matrix and ``dedicom_decoder``. The test builds a
    one-row ``TrainingBatch``, calls ``_per_layer_required_rows`` and prints
    the result.

    NOTE(review): this test only prints ``res`` and asserts nothing about
    it, so it can only fail if the call itself raises.
    """
    gene_gene = torch.tensor([
        [1, 0, 0, 1],
        [0, 1, 1, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 1],
    ]).to_sparse()

    # 4 x 5: rows match the 4 Gene vertices, columns the 5 Drug vertices.
    gene_drug = torch.tensor([
        [0, 1, 0, 0, 1],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 1],
        [0, 0, 1, 1, 0],
    ]).to_sparse()

    # Explicit integer identity (torch.eye would yield a float tensor).
    drug_drug = torch.tensor([
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
    ]).to_sparse()

    graph = Data()
    graph.add_vertex_type('Gene', 4)
    graph.add_vertex_type('Drug', 5)

    # (edge type name, first vertex type index, second vertex type index,
    # matrix). The indices refer to registration order: Gene=0, Drug=1.
    edge_specs = (
        ('Gene-Gene', 0, 0, gene_gene),
        ('Gene-Drug', 0, 1, gene_drug),
        ('Drug-Drug', 1, 1, drug_drug),
    )
    for edge_name, first_type, second_type, matrix in edge_specs:
        graph.add_edge_type(edge_name, first_type, second_type,
                            [matrix], dedicom_decoder)

    # NOTE(review): the meaning of the three leading ints and of the
    # [[0, 1]] tensor is defined by TrainingBatch; presumably vertex/edge
    # type indices plus (row, column) pairs — confirm against its definition.
    batch_indices = torch.tensor([
        [0, 1],
    ])
    training_batch = TrainingBatch(0, 1, 0, batch_indices)

    # NOTE(review): the trailing 5 is presumably the number of layers —
    # confirm against _per_layer_required_rows.
    res = _per_layer_required_rows(graph, training_batch, 5)
    print('res:', res)