|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- from icosagon.fastloop import FastBatcher, \
- FastModel
- from icosagon.data import Data
- from icosagon.trainprep import prepare_training, \
- TrainValTest
- import torch
-
-
- def test_fast_batcher_01():
- d = Data()
- d.add_node_type('Gene', 5)
- d.add_node_type('Drug', 3)
-
- fam = d.add_relation_family('Gene-Drug', 0, 1, True)
-
- adj_mat = torch.tensor([
- [ 1, 0, 1 ],
- [ 0, 0, 1 ],
- [ 0, 1, 0 ],
- [ 1, 0, 0 ],
- [ 0, 1, 1 ]
- ], dtype=torch.float32).to_sparse()
- fam.add_relation_type('Target', adj_mat)
-
- prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
- # print(prep_d.relation_families[0])
-
- g = torch.Generator()
- batcher = FastBatcher(prep_d, batch_size=3, shuffle=True,
- generator=g, part_type='train')
-
- print(batcher.edges)
- print(batcher.targets)
-
- edges_check = [ set() for _ in range(len(batcher.edges)) ]
-
- for fam_idx, edges, targets in batcher:
- print(fam_idx, edges, targets)
- for e in edges:
- edges_check[fam_idx].add(tuple(e.tolist()))
-
- edges_check_2 = [ set() for _ in range(len(batcher.edges)) ]
- for i, edges in enumerate(batcher.edges):
- for e in edges:
- edges_check_2[i].add(tuple(e.tolist()))
-
- assert edges_check == edges_check_2
-
-
- def test_fast_model_01():
- raise NotImplementedError
|