If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_layer.py 5.2KB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from decagon_pytorch.layer import InputLayer, \
  2. OneHotInputLayer, \
  3. DecagonLayer
  4. from decagon_pytorch.data import Data
  5. import torch
  6. import pytest
  7. from decagon_pytorch.convolve import SparseDropoutGraphConvActivation
  8. def _some_data():
  9. d = Data()
  10. d.add_node_type('Gene', 1000)
  11. d.add_node_type('Drug', 100)
  12. d.add_relation_type('Target', 1, 0, None)
  13. d.add_relation_type('Interaction', 0, 0, None)
  14. d.add_relation_type('Side Effect: Nausea', 1, 1, None)
  15. d.add_relation_type('Side Effect: Infertility', 1, 1, None)
  16. d.add_relation_type('Side Effect: Death', 1, 1, None)
  17. return d
  18. def _some_data_with_interactions():
  19. d = Data()
  20. d.add_node_type('Gene', 1000)
  21. d.add_node_type('Drug', 100)
  22. d.add_relation_type('Target', 1, 0,
  23. torch.rand((100, 1000), dtype=torch.float32).round())
  24. d.add_relation_type('Interaction', 0, 0,
  25. torch.rand((1000, 1000), dtype=torch.float32).round())
  26. d.add_relation_type('Side Effect: Nausea', 1, 1,
  27. torch.rand((100, 100), dtype=torch.float32).round())
  28. d.add_relation_type('Side Effect: Infertility', 1, 1,
  29. torch.rand((100, 100), dtype=torch.float32).round())
  30. d.add_relation_type('Side Effect: Death', 1, 1,
  31. torch.rand((100, 100), dtype=torch.float32).round())
  32. return d
  33. def test_input_layer_01():
  34. d = _some_data()
  35. for output_dim in [32, 64, 128]:
  36. layer = InputLayer(d, output_dim)
  37. assert layer.output_dim[0] == output_dim
  38. assert len(layer.node_reps) == 2
  39. assert layer.node_reps[0].shape == (1000, output_dim)
  40. assert layer.node_reps[1].shape == (100, output_dim)
  41. assert layer.data == d
  42. def test_input_layer_02():
  43. d = _some_data()
  44. layer = InputLayer(d, 32)
  45. res = layer()
  46. assert isinstance(res[0], torch.Tensor)
  47. assert isinstance(res[1], torch.Tensor)
  48. assert res[0].shape == (1000, 32)
  49. assert res[1].shape == (100, 32)
  50. assert torch.all(res[0] == layer.node_reps[0])
  51. assert torch.all(res[1] == layer.node_reps[1])
  52. def test_input_layer_03():
  53. if torch.cuda.device_count() == 0:
  54. pytest.skip('No CUDA devices on this host')
  55. d = _some_data()
  56. layer = InputLayer(d, 32)
  57. device = torch.device('cuda:0')
  58. layer = layer.to(device)
  59. print(list(layer.parameters()))
  60. # assert layer.device.type == 'cuda:0'
  61. assert layer.node_reps[0].device == device
  62. assert layer.node_reps[1].device == device
  63. def test_one_hot_input_layer_01():
  64. d = _some_data()
  65. layer = OneHotInputLayer(d)
  66. assert layer.output_dim == [1000, 100]
  67. assert len(layer.node_reps) == 2
  68. assert layer.node_reps[0].shape == (1000, 1000)
  69. assert layer.node_reps[1].shape == (100, 100)
  70. assert layer.data == d
  71. assert layer.is_sparse
  72. def test_one_hot_input_layer_02():
  73. d = _some_data()
  74. layer = OneHotInputLayer(d)
  75. res = layer()
  76. assert isinstance(res[0], torch.Tensor)
  77. assert isinstance(res[1], torch.Tensor)
  78. assert res[0].shape == (1000, 1000)
  79. assert res[1].shape == (100, 100)
  80. assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
  81. assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
  82. def test_one_hot_input_layer_03():
  83. if torch.cuda.device_count() == 0:
  84. pytest.skip('No CUDA devices on this host')
  85. d = _some_data()
  86. layer = OneHotInputLayer(d)
  87. device = torch.device('cuda:0')
  88. layer = layer.to(device)
  89. print(list(layer.parameters()))
  90. # assert layer.device.type == 'cuda:0'
  91. assert layer.node_reps[0].device == device
  92. assert layer.node_reps[1].device == device
  93. def test_decagon_layer_01():
  94. d = _some_data_with_interactions()
  95. in_layer = InputLayer(d)
  96. d_layer = DecagonLayer(d, in_layer, output_dim=32)
  97. def test_decagon_layer_02():
  98. d = _some_data_with_interactions()
  99. in_layer = OneHotInputLayer(d)
  100. d_layer = DecagonLayer(d, in_layer, output_dim=32)
  101. _ = d_layer() # dummy call
  102. def test_decagon_layer_03():
  103. d = _some_data_with_interactions()
  104. in_layer = OneHotInputLayer(d)
  105. d_layer = DecagonLayer(d, in_layer, output_dim=32)
  106. assert d_layer.data == d
  107. assert d_layer.previous_layer == in_layer
  108. assert d_layer.input_dim == [ 1000, 100 ]
  109. assert not d_layer.is_sparse
  110. assert d_layer.keep_prob == 1.
  111. assert d_layer.rel_activation(0.5) == 0.5
  112. x = torch.tensor([-1, 0, 0.5, 1])
  113. assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
  114. assert len(d_layer.next_layer_repr) == 2
  115. assert len(d_layer.next_layer_repr[0]) == 2
  116. assert len(d_layer.next_layer_repr[1]) == 4
  117. assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation),
  118. d_layer.next_layer_repr[0]))
  119. assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation),
  120. d_layer.next_layer_repr[1]))
  121. assert all(map(lambda a: a[0].output_dim == 32,
  122. d_layer.next_layer_repr[0]))
  123. assert all(map(lambda a: a[0].output_dim == 32,
  124. d_layer.next_layer_repr[1]))
  125. def test_decagon_layer_04():
  126. d = _some_data_with_interactions()
  127. in_layer = OneHotInputLayer(d)
  128. d_layer = DecagonLayer(d, in_layer, output_dim=32)
  129. _ = d_layer()