IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

170 lines
5.7KB

  1. from decagon_pytorch.layer import InputLayer, \
  2. OneHotInputLayer, \
  3. DecagonLayer
  4. from decagon_pytorch.data import Data
  5. import torch
  6. import pytest
  7. from decagon_pytorch.convolve import SparseDropoutGraphConvActivation, \
  8. SparseMultiDGCA, \
  9. DropoutGraphConvActivation
  10. def _some_data():
  11. d = Data()
  12. d.add_node_type('Gene', 1000)
  13. d.add_node_type('Drug', 100)
  14. d.add_relation_type('Target', 1, 0, None)
  15. d.add_relation_type('Interaction', 0, 0, None)
  16. d.add_relation_type('Side Effect: Nausea', 1, 1, None)
  17. d.add_relation_type('Side Effect: Infertility', 1, 1, None)
  18. d.add_relation_type('Side Effect: Death', 1, 1, None)
  19. return d
  20. def _some_data_with_interactions():
  21. d = Data()
  22. d.add_node_type('Gene', 1000)
  23. d.add_node_type('Drug', 100)
  24. d.add_relation_type('Target', 1, 0,
  25. torch.rand((100, 1000), dtype=torch.float32).round())
  26. d.add_relation_type('Interaction', 0, 0,
  27. torch.rand((1000, 1000), dtype=torch.float32).round())
  28. d.add_relation_type('Side Effect: Nausea', 1, 1,
  29. torch.rand((100, 100), dtype=torch.float32).round())
  30. d.add_relation_type('Side Effect: Infertility', 1, 1,
  31. torch.rand((100, 100), dtype=torch.float32).round())
  32. d.add_relation_type('Side Effect: Death', 1, 1,
  33. torch.rand((100, 100), dtype=torch.float32).round())
  34. return d
  35. def test_decagon_layer_01():
  36. d = _some_data_with_interactions()
  37. in_layer = InputLayer(d)
  38. d_layer = DecagonLayer(d, in_layer, output_dim=32)
  39. def test_decagon_layer_02():
  40. d = _some_data_with_interactions()
  41. in_layer = OneHotInputLayer(d)
  42. d_layer = DecagonLayer(d, in_layer, output_dim=32)
  43. _ = d_layer() # dummy call
  44. def test_decagon_layer_03():
  45. d = _some_data_with_interactions()
  46. in_layer = OneHotInputLayer(d)
  47. d_layer = DecagonLayer(d, in_layer, output_dim=32)
  48. assert d_layer.data == d
  49. assert d_layer.previous_layer == in_layer
  50. assert d_layer.input_dim == [ 1000, 100 ]
  51. assert not d_layer.is_sparse
  52. assert d_layer.keep_prob == 1.
  53. assert d_layer.rel_activation(0.5) == 0.5
  54. x = torch.tensor([-1, 0, 0.5, 1])
  55. assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
  56. assert len(d_layer.next_layer_repr) == 2
  57. for i in range(2):
  58. assert len(d_layer.next_layer_repr[i]) == 2
  59. assert isinstance(d_layer.next_layer_repr[i], list)
  60. assert isinstance(d_layer.next_layer_repr[i][0], tuple)
  61. assert isinstance(d_layer.next_layer_repr[i][0][0], list)
  62. assert isinstance(d_layer.next_layer_repr[i][0][1], int)
  63. assert all([
  64. isinstance(dgca, DropoutGraphConvActivation) \
  65. for dgca in d_layer.next_layer_repr[i][0][0]
  66. ])
  67. assert all([
  68. dgca.output_dim == 32 \
  69. for dgca in d_layer.next_layer_repr[i][0][0]
  70. ])
  71. def test_decagon_layer_04():
  72. # check if it is equivalent to MultiDGCA, as it should be
  73. d = Data()
  74. d.add_node_type('Dummy', 100)
  75. d.add_relation_type('Dummy Relation', 0, 0,
  76. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  77. in_layer = OneHotInputLayer(d)
  78. multi_dgca = SparseMultiDGCA([10], 32,
  79. [r.adjacency_matrix for r in d.relation_types[0, 0]],
  80. keep_prob=1., activation=lambda x: x)
  81. d_layer = DecagonLayer(d, in_layer, output_dim=32,
  82. keep_prob=1., rel_activation=lambda x: x,
  83. layer_activation=lambda x: x)
  84. assert isinstance(d_layer.next_layer_repr[0][0][0][0],
  85. DropoutGraphConvActivation)
  86. weight = d_layer.next_layer_repr[0][0][0][0].graph_conv.weight
  87. assert isinstance(weight, torch.Tensor)
  88. assert len(multi_dgca.sparse_dgca) == 1
  89. assert isinstance(multi_dgca.sparse_dgca[0], SparseDropoutGraphConvActivation)
  90. multi_dgca.sparse_dgca[0].sparse_graph_conv.weight = weight
  91. out_d_layer = d_layer()
  92. out_multi_dgca = multi_dgca(in_layer())
  93. assert isinstance(out_d_layer, list)
  94. assert len(out_d_layer) == 1
  95. assert torch.all(out_d_layer[0] == out_multi_dgca)
  96. def test_decagon_layer_05():
  97. # check if it is equivalent to MultiDGCA, as it should be
  98. # this time for two relations, same edge type
  99. d = Data()
  100. d.add_node_type('Dummy', 100)
  101. d.add_relation_type('Dummy Relation 1', 0, 0,
  102. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  103. d.add_relation_type('Dummy Relation 2', 0, 0,
  104. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  105. in_layer = OneHotInputLayer(d)
  106. multi_dgca = SparseMultiDGCA([100, 100], 32,
  107. [r.adjacency_matrix for r in d.relation_types[0, 0]],
  108. keep_prob=1., activation=lambda x: x)
  109. d_layer = DecagonLayer(d, in_layer, output_dim=32,
  110. keep_prob=1., rel_activation=lambda x: x,
  111. layer_activation=lambda x: x)
  112. assert all([
  113. isinstance(dgca, DropoutGraphConvActivation) \
  114. for dgca in d_layer.next_layer_repr[0][0][0]
  115. ])
  116. weight = [ dgca.graph_conv.weight \
  117. for dgca in d_layer.next_layer_repr[0][0][0] ]
  118. assert all([
  119. isinstance(w, torch.Tensor) \
  120. for w in weight
  121. ])
  122. assert len(multi_dgca.sparse_dgca) == 2
  123. for i in range(2):
  124. assert isinstance(multi_dgca.sparse_dgca[i], SparseDropoutGraphConvActivation)
  125. for i in range(2):
  126. multi_dgca.sparse_dgca[i].sparse_graph_conv.weight = weight[i]
  127. out_d_layer = d_layer()
  128. x = in_layer()
  129. out_multi_dgca = multi_dgca([ x[0], x[0] ])
  130. assert isinstance(out_d_layer, list)
  131. assert len(out_d_layer) == 1
  132. assert torch.all(out_d_layer[0] == out_multi_dgca)