|
- from triacontagon.model import _per_layer_required_rows, \
- TrainingBatch
- from triacontagon.decode import dedicom_decoder
- from triacontagon.data import Data
- import torch
-
-
def test_per_layer_required_rows_01():
    """Smoke-test ``_per_layer_required_rows`` on a tiny Gene/Drug graph.

    Registers two vertex types (4 genes, 5 drugs) and three edge types
    (Gene-Gene, Gene-Drug, Drug-Drug). Each edge type is backed by a single
    sparse adjacency matrix and uses ``dedicom_decoder``. A one-edge
    ``TrainingBatch`` is then passed to ``_per_layer_required_rows`` and the
    result is printed.

    NOTE(review): this test contains no assertions, so it can only fail if
    the call raises. Pin the expected return value once its format is
    settled.
    """
    def as_sparse(rows):
        # Dense nested list of ints -> sparse tensor. The original built
        # each adjacency matrix the same way.
        return torch.tensor(rows).to_sparse()

    # Vertex-type indices as passed to add_edge_type / TrainingBatch.
    # Presumably they follow the order in which add_vertex_type is called
    # below; confirm against Data.
    gene, drug = 0, 1

    data = Data()
    data.add_vertex_type('Gene', 4)
    data.add_vertex_type('Drug', 5)

    gene_gene = as_sparse([
        [1, 0, 0, 1],
        [0, 1, 1, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 1],
    ])
    gene_drug = as_sparse([
        [0, 1, 0, 0, 1],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 1],
        [0, 0, 1, 1, 0],
    ])
    # 5x5 identity: each drug is linked only to itself.
    drug_drug = as_sparse(
        [[1 if r == c else 0 for c in range(5)] for r in range(5)])

    edge_specs = (
        ('Gene-Gene', gene, gene, gene_gene),
        ('Gene-Drug', gene, drug, gene_drug),
        ('Drug-Drug', drug, drug, drug_drug),
    )
    for name, row_type, col_type, adjacency in edge_specs:
        # Each edge type receives a fresh single-element list of matrices.
        data.add_edge_type(name, row_type, col_type, [adjacency],
                           dedicom_decoder)

    # Single Gene->Drug edge (gene 0, drug 1). The third argument (0)
    # presumably selects the first relation matrix; confirm against
    # TrainingBatch.
    training_batch = TrainingBatch(gene, drug, 0, torch.tensor([[0, 1]]))

    # Third argument is 5; presumably the number of layers. TODO: confirm.
    result = _per_layer_required_rows(data, training_batch, 5)
    print('res:', result)
|