|  | from triacontagon.model import _per_layer_required_rows, \
    TrainingBatch
from triacontagon.decode import dedicom_decoder
from triacontagon.data import Data
import torch
def test_per_layer_required_rows_01():
    """Smoke test for ``_per_layer_required_rows`` on a tiny Gene/Drug graph.

    The graph has two vertex types (4 'Gene' vertices, 5 'Drug' vertices)
    and three edge types (Gene-Gene, Gene-Drug, Drug-Drug). Each edge type
    gets a single sparse adjacency matrix and ``dedicom_decoder``. A
    one-edge ``TrainingBatch`` is then passed to ``_per_layer_required_rows``
    and the result is printed.

    NOTE(review): nothing is asserted, so this passes whenever the call
    does not raise. Add assertions once the expected per-layer structure
    of the result is pinned down.
    """
    gene_gene = torch.tensor([
        [1, 0, 0, 1],
        [0, 1, 1, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 1],
    ])
    gene_drug = torch.tensor([
        [0, 1, 0, 0, 1],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 1],
        [0, 0, 1, 1, 0],
    ])
    # Drugs connect only to themselves: a 5x5 identity. It uses the same
    # int64 dtype that torch.tensor infers for the integer literals above.
    drug_drug = torch.eye(5, dtype=torch.int64)

    data = Data()
    for vertex_name, vertex_count in (('Gene', 4), ('Drug', 5)):
        data.add_vertex_type(vertex_name, vertex_count)

    # Each entry: (edge type name, row vertex type, column vertex type,
    # dense adjacency matrix). Type 0 is 'Gene' and type 1 is 'Drug',
    # matching the order in which they were added above.
    edge_specs = (
        ('Gene-Gene', 0, 0, gene_gene),
        ('Gene-Drug', 0, 1, gene_drug),
        ('Drug-Drug', 1, 1, drug_drug),
    )
    for edge_name, row_type, col_type, dense_adj in edge_specs:
        data.add_edge_type(edge_name, row_type, col_type,
            [dense_adj.to_sparse()], dedicom_decoder)

    # A single edge [0, 1]. The positional args are presumably
    # (row vertex type, column vertex type, relation index, edges);
    # confirm against TrainingBatch.
    single_edge = torch.tensor([[0, 1]])
    batch = TrainingBatch(0, 1, 0, single_edge)

    # The 5 is presumably the number of layers to expand over; TODO confirm.
    required = _per_layer_required_rows(data, batch, 5)
    print('res:', required)
 |