from icosagon.fastloop import (
    FastBatcher,
    FastModel,
)
from icosagon.data import Data
from icosagon.trainprep import (
    prepare_training,
    TrainValTest,
)
import torch


def test_fast_batcher_01():
    """Check that iterating a FastBatcher covers the edges it stores.

    Builds a small Gene/Drug dataset with one 'Target' relation type,
    prepares it with an 0.8/0.1/0.1 train/val/test split, and then verifies
    that the edges yielded batch by batch (grouped by family index) form
    the same per-family sets as the rows of ``batcher.edges``.
    """
    data = Data()
    data.add_node_type('Gene', 5)
    data.add_node_type('Drug', 3)

    # 5 x 3 adjacency matrix: one row per Gene node, one column per Drug node.
    dense_adjacency = torch.tensor(
        [
            [1, 0, 1],
            [0, 0, 1],
            [0, 1, 0],
            [1, 0, 0],
            [0, 1, 1],
        ],
        dtype=torch.float32,
    )
    family = data.add_relation_family('Gene-Drug', 0, 1, True)
    family.add_relation_type('Target', dense_adjacency.to_sparse())

    prepared = prepare_training(data, TrainValTest(0.8, 0.1, 0.1))

    rng = torch.Generator()
    batcher = FastBatcher(
        prepared,
        batch_size=3,
        shuffle=True,
        generator=rng,
        part_type='train',
    )
    print(batcher.edges)
    print(batcher.targets)

    # Pass 1: gather the edges actually yielded by iteration, bucketed by
    # the family index reported alongside each batch.
    seen_in_batches = [set() for _ in range(len(batcher.edges))]
    for fam_idx, edges, targets in batcher:
        print(fam_idx, edges, targets)
        seen_in_batches[fam_idx].update(tuple(e.tolist()) for e in edges)

    # Pass 2: gather the edges directly from the batcher's stored tensors,
    # one set per entry of ``batcher.edges``.
    stored_edges = [
        {tuple(e.tolist()) for e in family_edges}
        for family_edges in batcher.edges
    ]

    assert seen_in_batches == stored_edges


def test_fast_model_01():
    """Placeholder for the FastModel test; not implemented yet."""
    raise NotImplementedError