from decagon_pytorch.layer import InputLayer, \
    OneHotInputLayer, \
    DecagonLayer
from decagon_pytorch.data import Data
import torch
import pytest
from decagon_pytorch.convolve import SparseDropoutGraphConvActivation, \
    SparseMultiDGCA, \
    DropoutGraphConvActivation

def _some_data():
    d = Data()
    d.add_node_type('Gene', 1000)
    d.add_node_type('Drug', 100)
    d.add_relation_type('Target', 1, 0, None)
    d.add_relation_type('Interaction', 0, 0, None)
    d.add_relation_type('Side Effect: Nausea', 1, 1, None)
    d.add_relation_type('Side Effect: Infertility', 1, 1, None)
    d.add_relation_type('Side Effect: Death', 1, 1, None)
    return d

def _some_data_with_interactions():
    d = Data()
    d.add_node_type('Gene', 1000)
    d.add_node_type('Drug', 100)
    d.add_relation_type('Target', 1, 0,
        torch.rand((100, 1000), dtype=torch.float32).round())
    d.add_relation_type('Interaction', 0, 0,
        torch.rand((1000, 1000), dtype=torch.float32).round())
    d.add_relation_type('Side Effect: Nausea', 1, 1,
        torch.rand((100, 100), dtype=torch.float32).round())
    d.add_relation_type('Side Effect: Infertility', 1, 1,
        torch.rand((100, 100), dtype=torch.float32).round())
    d.add_relation_type('Side Effect: Death', 1, 1,
        torch.rand((100, 100), dtype=torch.float32).round())
    return d
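
# Note (added commentary, not in the original file): judging from the calls
# above, add_relation_type appears to take (name, node_type_row,
# node_type_column, adjacency_matrix), with node type 0 being 'Gene' (1000
# nodes) and node type 1 being 'Drug' (100 nodes); e.g. the 'Target' relation
# declared with (1, 0) receives a (100, 1000) drug-by-gene matrix, and passing
# None presumably registers the relation without an adjacency matrix.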

def test_input_layer_01():
    d = _some_data()
    for output_dim in [32, 64, 128]:
        layer = InputLayer(d, output_dim)
        assert layer.output_dim[0] == output_dim
        assert len(layer.node_reps) == 2
        assert layer.node_reps[0].shape == (1000, output_dim)
        assert layer.node_reps[1].shape == (100, output_dim)
        assert layer.data == d

def test_input_layer_02():
    d = _some_data()
    layer = InputLayer(d, 32)
    res = layer()
    assert isinstance(res[0], torch.Tensor)
    assert isinstance(res[1], torch.Tensor)
    assert res[0].shape == (1000, 32)
    assert res[1].shape == (100, 32)
    assert torch.all(res[0] == layer.node_reps[0])
    assert torch.all(res[1] == layer.node_reps[1])

def test_input_layer_03():
    if torch.cuda.device_count() == 0:
        pytest.skip('No CUDA devices on this host')
    d = _some_data()
    layer = InputLayer(d, 32)
    device = torch.device('cuda:0')
    layer = layer.to(device)
    print(list(layer.parameters()))
    # assert layer.device.type == 'cuda:0'
    assert layer.node_reps[0].device == device
    assert layer.node_reps[1].device == device

def test_one_hot_input_layer_01():
    d = _some_data()
    layer = OneHotInputLayer(d)
    assert layer.output_dim == [1000, 100]
    assert len(layer.node_reps) == 2
    assert layer.node_reps[0].shape == (1000, 1000)
    assert layer.node_reps[1].shape == (100, 100)
    assert layer.data == d
    assert layer.is_sparse

def test_one_hot_input_layer_02():
    d = _some_data()
    layer = OneHotInputLayer(d)
    res = layer()
    assert isinstance(res[0], torch.Tensor)
    assert isinstance(res[1], torch.Tensor)
    assert res[0].shape == (1000, 1000)
    assert res[1].shape == (100, 100)
    assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
    assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())

def test_one_hot_input_layer_03():
    if torch.cuda.device_count() == 0:
        pytest.skip('No CUDA devices on this host')
    d = _some_data()
    layer = OneHotInputLayer(d)
    device = torch.device('cuda:0')
    layer = layer.to(device)
    print(list(layer.parameters()))
    # assert layer.device.type == 'cuda:0'
    assert layer.node_reps[0].device == device
    assert layer.node_reps[1].device == device

def test_decagon_layer_01():
    d = _some_data_with_interactions()
    in_layer = InputLayer(d)
    d_layer = DecagonLayer(d, in_layer, output_dim=32)

def test_decagon_layer_02():
    d = _some_data_with_interactions()
    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(d, in_layer, output_dim=32)
    _ = d_layer() # dummy call

def test_decagon_layer_03():
    d = _some_data_with_interactions()
    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(d, in_layer, output_dim=32)
    assert d_layer.data == d
    assert d_layer.previous_layer == in_layer
    assert d_layer.input_dim == [ 1000, 100 ]
    assert not d_layer.is_sparse
    assert d_layer.keep_prob == 1.
    assert d_layer.rel_activation(0.5) == 0.5
    x = torch.tensor([-1, 0, 0.5, 1])
    assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
    assert len(d_layer.next_layer_repr) == 2
    for i in range(2):
        assert len(d_layer.next_layer_repr[i]) == 2
        assert isinstance(d_layer.next_layer_repr[i], list)
        assert isinstance(d_layer.next_layer_repr[i][0], tuple)
        assert isinstance(d_layer.next_layer_repr[i][0][0], list)
        assert isinstance(d_layer.next_layer_repr[i][0][1], int)
        assert all([
            isinstance(dgca, DropoutGraphConvActivation)
            for dgca in d_layer.next_layer_repr[i][0][0]
        ])
        assert all([
            dgca.output_dim == 32
            for dgca in d_layer.next_layer_repr[i][0][0]
        ])
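
# Note (added commentary, not in the original file): the assertions in
# test_decagon_layer_03 pin down the apparent structure of next_layer_repr:
# a list with one entry per node type, each entry being a list of
# (convolutions, index) tuples, where `convolutions` is a list of
# DropoutGraphConvActivation modules sharing the layer's output_dim and
# `index` is an int whose exact meaning (presumably the source node type)
# is not asserted here.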

def test_decagon_layer_04():
    # check if it is equivalent to MultiDGCA, as it should be
    d = Data()
    d.add_node_type('Dummy', 100)
    d.add_relation_type('Dummy Relation', 0, 0,
        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
    in_layer = OneHotInputLayer(d)
    multi_dgca = SparseMultiDGCA([10], 32,
        [r.adjacency_matrix for r in d.relation_types[0, 0]],
        keep_prob=1., activation=lambda x: x)
    d_layer = DecagonLayer(d, in_layer, output_dim=32,
        keep_prob=1., rel_activation=lambda x: x,
        layer_activation=lambda x: x)
    assert isinstance(d_layer.next_layer_repr[0][0][0][0],
        DropoutGraphConvActivation)
    weight = d_layer.next_layer_repr[0][0][0][0].graph_conv.weight
    assert isinstance(weight, torch.Tensor)
    assert len(multi_dgca.sparse_dgca) == 1
    assert isinstance(multi_dgca.sparse_dgca[0], SparseDropoutGraphConvActivation)
    multi_dgca.sparse_dgca[0].sparse_graph_conv.weight = weight
    out_d_layer = d_layer()
    out_multi_dgca = multi_dgca(in_layer())
    assert isinstance(out_d_layer, list)
    assert len(out_d_layer) == 1
    assert torch.all(out_d_layer[0] == out_multi_dgca)

def test_decagon_layer_05():
    # check if it is equivalent to MultiDGCA, as it should be
    # this time for two relations, same edge type
    d = Data()
    d.add_node_type('Dummy', 100)
    d.add_relation_type('Dummy Relation 1', 0, 0,
        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
    d.add_relation_type('Dummy Relation 2', 0, 0,
        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
    in_layer = OneHotInputLayer(d)
    multi_dgca = SparseMultiDGCA([100, 100], 32,
        [r.adjacency_matrix for r in d.relation_types[0, 0]],
        keep_prob=1., activation=lambda x: x)
    d_layer = DecagonLayer(d, in_layer, output_dim=32,
        keep_prob=1., rel_activation=lambda x: x,
        layer_activation=lambda x: x)
    assert all([
        isinstance(dgca, DropoutGraphConvActivation)
        for dgca in d_layer.next_layer_repr[0][0][0]
    ])
    weight = [ dgca.graph_conv.weight
        for dgca in d_layer.next_layer_repr[0][0][0] ]
    assert all([
        isinstance(w, torch.Tensor)
        for w in weight
    ])
    assert len(multi_dgca.sparse_dgca) == 2
    for i in range(2):
        assert isinstance(multi_dgca.sparse_dgca[i], SparseDropoutGraphConvActivation)
    for i in range(2):
        multi_dgca.sparse_dgca[i].sparse_graph_conv.weight = weight[i]
    out_d_layer = d_layer()
    x = in_layer()
    out_multi_dgca = multi_dgca([ x[0], x[0] ])
    assert isinstance(out_d_layer, list)
    assert len(out_d_layer) == 1
    assert torch.all(out_d_layer[0] == out_multi_dgca)
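
# ---------------------------------------------------------------------------
# The helper below is not part of the original test file. It is a minimal,
# illustrative sketch of the workflow exercised by the tests above, using only
# constructors and calls that already appear in them: build a Data object,
# wrap it in a OneHotInputLayer, stack a DecagonLayer on top and call it.
# Based on test_decagon_layer_04/05, the result appears to be a list with one
# representation tensor per node type.
def _example_forward_pass():
    d = _some_data_with_interactions()
    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(d, in_layer, output_dim=32)
    return d_layer()

# These tests follow standard pytest conventions; assuming decagon_pytorch is
# importable (e.g. installed with `pip install -e .`), running `pytest -v`
# from the repository root should execute them, with the CUDA-dependent tests
# skipping themselves automatically on hosts without a GPU.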