IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

41 lines
1.1KB

  1. from triacontagon.model import _per_layer_required_rows, \
  2. TrainingBatch
  3. from triacontagon.decode import dedicom_decoder
  4. from triacontagon.data import Data
  5. import torch
  6. def test_per_layer_required_rows_01():
  7. d = Data()
  8. d.add_vertex_type('Gene', 4)
  9. d.add_vertex_type('Drug', 5)
  10. d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
  11. [1, 0, 0, 1],
  12. [0, 1, 1, 0],
  13. [0, 0, 1, 0],
  14. [0, 1, 0, 1]
  15. ]).to_sparse() ], dedicom_decoder)
  16. d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
  17. [0, 1, 0, 0, 1],
  18. [0, 0, 1, 0, 0],
  19. [1, 0, 0, 0, 1],
  20. [0, 0, 1, 1, 0]
  21. ]).to_sparse() ], dedicom_decoder)
  22. d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
  23. [1, 0, 0, 0, 0],
  24. [0, 1, 0, 0, 0],
  25. [0, 0, 1, 0, 0],
  26. [0, 0, 0, 1, 0],
  27. [0, 0, 0, 0, 1]
  28. ]).to_sparse() ], dedicom_decoder)
  29. batch = TrainingBatch(0, 1, 0, torch.tensor([
  30. [0, 1]
  31. ]))
  32. res = _per_layer_required_rows(d, batch, 5)
  33. print('res:', res)