IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

125 lines
4.0KB

  1. from decagon_pytorch.layer import InputLayer, \
  2. OneHotInputLayer, \
  3. DecagonLayer
  4. from decagon_pytorch.data import Data
  5. import torch
  6. import pytest
  7. def _some_data():
  8. d = Data()
  9. d.add_node_type('Gene', 1000)
  10. d.add_node_type('Drug', 100)
  11. d.add_relation_type('Target', 1, 0, None)
  12. d.add_relation_type('Interaction', 0, 0, None)
  13. d.add_relation_type('Side Effect: Nausea', 1, 1, None)
  14. d.add_relation_type('Side Effect: Infertility', 1, 1, None)
  15. d.add_relation_type('Side Effect: Death', 1, 1, None)
  16. return d
  17. def _some_data_with_interactions():
  18. d = Data()
  19. d.add_node_type('Gene', 1000)
  20. d.add_node_type('Drug', 100)
  21. d.add_relation_type('Target', 1, 0,
  22. torch.rand((100, 1000), dtype=torch.float32).round())
  23. d.add_relation_type('Interaction', 0, 0,
  24. torch.rand((1000, 1000), dtype=torch.float32).round())
  25. d.add_relation_type('Side Effect: Nausea', 1, 1,
  26. torch.rand((100, 100), dtype=torch.float32).round())
  27. d.add_relation_type('Side Effect: Infertility', 1, 1,
  28. torch.rand((100, 100), dtype=torch.float32).round())
  29. d.add_relation_type('Side Effect: Death', 1, 1,
  30. torch.rand((100, 100), dtype=torch.float32).round())
  31. return d
  32. def test_input_layer_01():
  33. d = _some_data()
  34. for output_dim in [32, 64, 128]:
  35. layer = InputLayer(d, output_dim)
  36. assert layer.output_dim[0] == output_dim
  37. assert len(layer.node_reps) == 2
  38. assert layer.node_reps[0].shape == (1000, output_dim)
  39. assert layer.node_reps[1].shape == (100, output_dim)
  40. assert layer.data == d
  41. def test_input_layer_02():
  42. d = _some_data()
  43. layer = InputLayer(d, 32)
  44. res = layer()
  45. assert isinstance(res[0], torch.Tensor)
  46. assert isinstance(res[1], torch.Tensor)
  47. assert res[0].shape == (1000, 32)
  48. assert res[1].shape == (100, 32)
  49. assert torch.all(res[0] == layer.node_reps[0])
  50. assert torch.all(res[1] == layer.node_reps[1])
  51. def test_input_layer_03():
  52. if torch.cuda.device_count() == 0:
  53. pytest.skip('No CUDA devices on this host')
  54. d = _some_data()
  55. layer = InputLayer(d, 32)
  56. device = torch.device('cuda:0')
  57. layer = layer.to(device)
  58. print(list(layer.parameters()))
  59. # assert layer.device.type == 'cuda:0'
  60. assert layer.node_reps[0].device == device
  61. assert layer.node_reps[1].device == device
  62. def test_one_hot_input_layer_01():
  63. d = _some_data()
  64. layer = OneHotInputLayer(d)
  65. assert layer.output_dim == [1000, 100]
  66. assert len(layer.node_reps) == 2
  67. assert layer.node_reps[0].shape == (1000, 1000)
  68. assert layer.node_reps[1].shape == (100, 100)
  69. assert layer.data == d
  70. assert layer.is_sparse
  71. def test_one_hot_input_layer_02():
  72. d = _some_data()
  73. layer = OneHotInputLayer(d)
  74. res = layer()
  75. assert isinstance(res[0], torch.Tensor)
  76. assert isinstance(res[1], torch.Tensor)
  77. assert res[0].shape == (1000, 1000)
  78. assert res[1].shape == (100, 100)
  79. assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
  80. assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
  81. def test_one_hot_input_layer_03():
  82. if torch.cuda.device_count() == 0:
  83. pytest.skip('No CUDA devices on this host')
  84. d = _some_data()
  85. layer = OneHotInputLayer(d)
  86. device = torch.device('cuda:0')
  87. layer = layer.to(device)
  88. print(list(layer.parameters()))
  89. # assert layer.device.type == 'cuda:0'
  90. assert layer.node_reps[0].device == device
  91. assert layer.node_reps[1].device == device
  92. def test_decagon_layer_01():
  93. d = _some_data_with_interactions()
  94. in_layer = InputLayer(d)
  95. d_layer = DecagonLayer(d, in_layer, output_dim=32)
  96. def test_decagon_layer_02():
  97. d = _some_data_with_interactions()
  98. in_layer = OneHotInputLayer(d)
  99. d_layer = DecagonLayer(d, in_layer, output_dim=32)
  100. _ = d_layer() # dummy call
  101. def test_decagon_layer_03():
  102. pass