IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

49 lines
1.3KB

  1. import torch
  2. from triacontagon.model import Model
  3. from triacontagon.data import Data
  4. from triacontagon.decode import dedicom_decoder
  5. def test_model_convolve_01():
  6. d = Data()
  7. d.add_vertex_type('Gene', 4)
  8. d.add_vertex_type('Drug', 5)
  9. d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
  10. [1, 0, 0, 1],
  11. [0, 1, 1, 0],
  12. [0, 0, 1, 0],
  13. [0, 1, 0, 1]
  14. ], dtype=torch.float).to_sparse() ], dedicom_decoder)
  15. d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
  16. [0, 1, 0, 0, 1],
  17. [0, 0, 1, 0, 0],
  18. [1, 0, 0, 0, 1],
  19. [0, 0, 1, 1, 0]
  20. ], dtype=torch.float).to_sparse() ], dedicom_decoder)
  21. d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
  22. [1, 0, 0, 0, 0],
  23. [0, 1, 0, 0, 0],
  24. [0, 0, 1, 0, 0],
  25. [0, 0, 0, 1, 0],
  26. [0, 0, 0, 0, 1]
  27. ], dtype=torch.float).to_sparse() ], dedicom_decoder)
  28. model = Model(d, [9, 32, 64], keep_prob=1.0,
  29. conv_activation = torch.sigmoid,
  30. dec_activation = torch.sigmoid)
  31. repr_1 = torch.eye(9)
  32. repr_1[4:, 4:] = 0
  33. repr_2 = torch.eye(9)
  34. repr_2[:4, :4] = 0
  35. in_layer_repr = [
  36. repr_1[:4, :].to_sparse(),
  37. repr_2[4:, :].to_sparse()
  38. ]
  39. _ = model.convolve(in_layer_repr)