If you would like to get an account, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

121 lines
3.9KB

  1. from icosagon.input import InputLayer, \
  2. OneHotInputLayer
  3. from icosagon.data import Data
  4. import torch
  5. import pytest
  6. def _some_data():
  7. d = Data()
  8. d.add_node_type('Gene', 1000)
  9. d.add_node_type('Drug', 100)
  10. d.add_relation_type('Target', 1, 0, torch.rand(100, 1000))
  11. d.add_relation_type('Interaction', 0, 0, torch.rand(1000, 1000))
  12. d.add_relation_type('Side Effect: Nausea', 1, 1, torch.rand(100, 100))
  13. d.add_relation_type('Side Effect: Infertility', 1, 1, torch.rand(100, 100))
  14. d.add_relation_type('Side Effect: Death', 1, 1, torch.rand(100, 100))
  15. return d
  16. def _some_data_with_interactions():
  17. d = Data()
  18. d.add_node_type('Gene', 1000)
  19. d.add_node_type('Drug', 100)
  20. d.add_relation_type('Target', 1, 0,
  21. torch.rand((100, 1000), dtype=torch.float32).round())
  22. d.add_relation_type('Interaction', 0, 0,
  23. torch.rand((1000, 1000), dtype=torch.float32).round())
  24. d.add_relation_type('Side Effect: Nausea', 1, 1,
  25. torch.rand((100, 100), dtype=torch.float32).round())
  26. d.add_relation_type('Side Effect: Infertility', 1, 1,
  27. torch.rand((100, 100), dtype=torch.float32).round())
  28. d.add_relation_type('Side Effect: Death', 1, 1,
  29. torch.rand((100, 100), dtype=torch.float32).round())
  30. return d
  31. def test_input_layer_01():
  32. d = _some_data()
  33. for output_dim in [32, 64, 128]:
  34. layer = InputLayer(d, output_dim)
  35. assert layer.output_dim[0] == output_dim
  36. assert len(layer.node_reps) == 2
  37. assert layer.node_reps[0].shape == (1000, output_dim)
  38. assert layer.node_reps[1].shape == (100, output_dim)
  39. assert layer.data == d
  40. def test_input_layer_02():
  41. d = _some_data()
  42. layer = InputLayer(d, 32)
  43. res = layer(None)
  44. assert isinstance(res[0], torch.Tensor)
  45. assert isinstance(res[1], torch.Tensor)
  46. assert res[0].shape == (1000, 32)
  47. assert res[1].shape == (100, 32)
  48. assert torch.all(res[0] == layer.node_reps[0])
  49. assert torch.all(res[1] == layer.node_reps[1])
  50. def test_input_layer_03():
  51. if torch.cuda.device_count() == 0:
  52. pytest.skip('No CUDA devices on this host')
  53. d = _some_data()
  54. layer = InputLayer(d, 32)
  55. device = torch.device('cuda:0')
  56. layer = layer.to(device)
  57. print(list(layer.parameters()))
  58. # assert layer.device.type == 'cuda:0'
  59. assert layer.node_reps[0].device == device
  60. assert layer.node_reps[1].device == device
  61. def test_input_layer_04():
  62. d = _some_data()
  63. layer = InputLayer(d, 32)
  64. s = repr(layer)
  65. assert s.startswith('Icosagon input layer')
  66. def test_one_hot_input_layer_01():
  67. d = _some_data()
  68. layer = OneHotInputLayer(d)
  69. assert layer.output_dim == [1000, 100]
  70. assert len(layer.node_reps) == 2
  71. assert layer.node_reps[0].shape == (1000, 1000)
  72. assert layer.node_reps[1].shape == (100, 100)
  73. assert layer.data == d
  74. assert layer.is_sparse
  75. def test_one_hot_input_layer_02():
  76. d = _some_data()
  77. layer = OneHotInputLayer(d)
  78. res = layer(None)
  79. assert isinstance(res[0], torch.Tensor)
  80. assert isinstance(res[1], torch.Tensor)
  81. assert res[0].shape == (1000, 1000)
  82. assert res[1].shape == (100, 100)
  83. assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
  84. assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
  85. def test_one_hot_input_layer_03():
  86. if torch.cuda.device_count() == 0:
  87. pytest.skip('No CUDA devices on this host')
  88. d = _some_data()
  89. layer = OneHotInputLayer(d)
  90. device = torch.device('cuda:0')
  91. layer = layer.to(device)
  92. print(list(layer.parameters()))
  93. # assert layer.device.type == 'cuda:0'
  94. assert layer.node_reps[0].device == device
  95. assert layer.node_reps[1].device == device
  96. def test_one_hot_input_layer_04():
  97. d = _some_data()
  98. layer = OneHotInputLayer(d)
  99. s = repr(layer)
  100. assert s.startswith('Icosagon one-hot input layer')