|                                                   | 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 | from icosagon.fastloop import FastBatcher, \
    FastModel
from icosagon.data import Data
from icosagon.trainprep import prepare_training, \
    TrainValTest
import torch
def test_fast_batcher_01():
    """Check that one pass over a FastBatcher yields exactly ``batcher.edges``.

    Builds a 'Gene' (5 nodes) / 'Drug' (3 nodes) dataset with a single
    'Target' relation given by a 5x3 adjacency matrix containing 7 ones.
    It prepares it with ``TrainValTest(.8, .1, .1)`` and iterates the
    'train' part in shuffled batches of size 3.

    The edges collected from every batch, grouped by the family index the
    batcher reports, must equal ``batcher.edges`` as a *multiset*. This
    means every edge is yielded, none goes missing, and none is yielded
    more than once.
    """
    from collections import Counter

    d = Data()
    d.add_node_type('Gene', 5)
    d.add_node_type('Drug', 3)
    fam = d.add_relation_family('Gene-Drug', 0, 1, True)
    adj_mat = torch.tensor([
        [ 1, 0, 1 ],
        [ 0, 0, 1 ],
        [ 0, 1, 0 ],
        [ 1, 0, 0 ],
        [ 0, 1, 1 ]
    ], dtype=torch.float32).to_sparse()
    fam.add_relation_type('Target', adj_mat)
    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))

    # Seed the shuffle generator so that a failing run can be reproduced.
    g = torch.Generator()
    g.manual_seed(0)
    batcher = FastBatcher(prep_d, batch_size=3, shuffle=True,
        generator=g, part_type='train')

    # Count edges rather than collecting them into sets. A set comparison
    # would still pass if the batcher yielded the same edge twice, or
    # dropped one copy of an edge that appears twice in batcher.edges.
    seen = [ Counter() for _ in range(len(batcher.edges)) ]
    for fam_idx, edges, _targets in batcher:
        seen[fam_idx].update(tuple(e.tolist()) for e in edges)

    expected = [ Counter(tuple(e.tolist()) for e in fam_edges)
        for fam_edges in batcher.edges ]
    assert seen == expected
def test_fast_model_01():
    """Placeholder for a FastModel test that has not been written yet.

    Raises ``unittest.SkipTest``, which pytest reports as a skip, instead of
    ``NotImplementedError``. The missing test therefore stays visible in the
    test report without erroring the whole suite on every run.
    """
    import unittest
    raise unittest.SkipTest('TODO: test_fast_model_01 is not implemented yet')
 |