If you would like to get an account, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

52 lines
1.4KB

  1. from icosagon.fastloop import FastBatcher, \
  2. FastModel
  3. from icosagon.data import Data
  4. from icosagon.trainprep import prepare_training, \
  5. TrainValTest
  6. import torch
  7. def test_fast_batcher_01():
  8. d = Data()
  9. d.add_node_type('Gene', 5)
  10. d.add_node_type('Drug', 3)
  11. fam = d.add_relation_family('Gene-Drug', 0, 1, True)
  12. adj_mat = torch.tensor([
  13. [ 1, 0, 1 ],
  14. [ 0, 0, 1 ],
  15. [ 0, 1, 0 ],
  16. [ 1, 0, 0 ],
  17. [ 0, 1, 1 ]
  18. ], dtype=torch.float32).to_sparse()
  19. fam.add_relation_type('Target', adj_mat)
  20. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  21. # print(prep_d.relation_families[0])
  22. g = torch.Generator()
  23. batcher = FastBatcher(prep_d, batch_size=3, shuffle=True,
  24. generator=g, part_type='train')
  25. print(batcher.edges)
  26. print(batcher.targets)
  27. edges_check = [ set() for _ in range(len(batcher.edges)) ]
  28. for fam_idx, edges, targets in batcher:
  29. print(fam_idx, edges, targets)
  30. for e in edges:
  31. edges_check[fam_idx].add(tuple(e.tolist()))
  32. edges_check_2 = [ set() for _ in range(len(batcher.edges)) ]
  33. for i, edges in enumerate(batcher.edges):
  34. for e in edges:
  35. edges_check_2[i].add(tuple(e.tolist()))
  36. assert edges_check == edges_check_2
  37. def test_fast_model_01():
  38. raise NotImplementedError