IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

168 lignes
5.7KB

  1. from icosagon.input import InputLayer, \
  2. OneHotInputLayer
  3. from icosagon.convlayer import DecagonLayer, \
  4. Convolutions
  5. from icosagon.data import Data
  6. import torch
  7. import pytest
  8. from icosagon.convolve import DropoutGraphConvActivation
  9. from decagon_pytorch.convolve import MultiDGCA
  10. import decagon_pytorch.convolve
  11. def _some_data_with_interactions():
  12. d = Data()
  13. d.add_node_type('Gene', 1000)
  14. d.add_node_type('Drug', 100)
  15. d.add_relation_type('Target', 1, 0,
  16. torch.rand((100, 1000), dtype=torch.float32).round())
  17. d.add_relation_type('Interaction', 0, 0,
  18. torch.rand((1000, 1000), dtype=torch.float32).round())
  19. d.add_relation_type('Side Effect: Nausea', 1, 1,
  20. torch.rand((100, 100), dtype=torch.float32).round())
  21. d.add_relation_type('Side Effect: Infertility', 1, 1,
  22. torch.rand((100, 100), dtype=torch.float32).round())
  23. d.add_relation_type('Side Effect: Death', 1, 1,
  24. torch.rand((100, 100), dtype=torch.float32).round())
  25. return d
  26. def test_decagon_layer_01():
  27. d = _some_data_with_interactions()
  28. in_layer = InputLayer(d)
  29. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  30. seq = torch.nn.Sequential(in_layer, d_layer)
  31. _ = seq(None) # dummy call
  32. def test_decagon_layer_02():
  33. d = _some_data_with_interactions()
  34. in_layer = OneHotInputLayer(d)
  35. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  36. seq = torch.nn.Sequential(in_layer, d_layer)
  37. _ = seq(None) # dummy call
  38. def test_decagon_layer_03():
  39. d = _some_data_with_interactions()
  40. in_layer = OneHotInputLayer(d)
  41. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  42. assert d_layer.input_dim == [ 1000, 100 ]
  43. assert d_layer.output_dim == [ 32, 32 ]
  44. assert d_layer.data == d
  45. assert d_layer.keep_prob == 1.
  46. assert d_layer.rel_activation(0.5) == 0.5
  47. x = torch.tensor([-1, 0, 0.5, 1])
  48. assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
  49. assert not d_layer.is_sparse
  50. assert len(d_layer.next_layer_repr) == 2
  51. for i in range(2):
  52. assert len(d_layer.next_layer_repr[i]) == 2
  53. assert isinstance(d_layer.next_layer_repr[i], list)
  54. assert isinstance(d_layer.next_layer_repr[i][0], Convolutions)
  55. assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int)
  56. assert isinstance(d_layer.next_layer_repr[i][0].convolutions, list)
  57. assert all([
  58. isinstance(dgca, DropoutGraphConvActivation) \
  59. for dgca in d_layer.next_layer_repr[i][0].convolutions
  60. ])
  61. assert all([
  62. dgca.output_dim == 32 \
  63. for dgca in d_layer.next_layer_repr[i][0].convolutions
  64. ])
  65. def test_decagon_layer_04():
  66. # check if it is equivalent to MultiDGCA, as it should be
  67. d = Data()
  68. d.add_node_type('Dummy', 100)
  69. d.add_relation_type('Dummy Relation', 0, 0,
  70. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  71. in_layer = OneHotInputLayer(d)
  72. multi_dgca = MultiDGCA([10], 32,
  73. [r.adjacency_matrix for r in d.relation_types[0][0]],
  74. keep_prob=1., activation=lambda x: x)
  75. d_layer = DecagonLayer(in_layer.output_dim, 32, d,
  76. keep_prob=1., rel_activation=lambda x: x,
  77. layer_activation=lambda x: x)
  78. assert isinstance(d_layer.next_layer_repr[0][0].convolutions[0],
  79. DropoutGraphConvActivation)
  80. weight = d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight
  81. assert isinstance(weight, torch.Tensor)
  82. assert len(multi_dgca.dgca) == 1
  83. assert isinstance(multi_dgca.dgca[0],
  84. decagon_pytorch.convolve.DropoutGraphConvActivation)
  85. multi_dgca.dgca[0].graph_conv.weight = weight
  86. seq = torch.nn.Sequential(in_layer, d_layer)
  87. out_d_layer = seq(None)
  88. out_multi_dgca = multi_dgca(in_layer(None))
  89. assert isinstance(out_d_layer, list)
  90. assert len(out_d_layer) == 1
  91. assert torch.all(out_d_layer[0] == out_multi_dgca)
  92. def test_decagon_layer_05():
  93. # check if it is equivalent to MultiDGCA, as it should be
  94. # this time for two relations, same edge type
  95. d = Data()
  96. d.add_node_type('Dummy', 100)
  97. d.add_relation_type('Dummy Relation 1', 0, 0,
  98. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  99. d.add_relation_type('Dummy Relation 2', 0, 0,
  100. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  101. in_layer = OneHotInputLayer(d)
  102. multi_dgca = MultiDGCA([100, 100], 32,
  103. [r.adjacency_matrix for r in d.relation_types[0][0]],
  104. keep_prob=1., activation=lambda x: x)
  105. d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
  106. keep_prob=1., rel_activation=lambda x: x,
  107. layer_activation=lambda x: x)
  108. assert all([
  109. isinstance(dgca, DropoutGraphConvActivation) \
  110. for dgca in d_layer.next_layer_repr[0][0].convolutions
  111. ])
  112. weight = [ dgca.graph_conv.weight \
  113. for dgca in d_layer.next_layer_repr[0][0].convolutions ]
  114. assert all([
  115. isinstance(w, torch.Tensor) \
  116. for w in weight
  117. ])
  118. assert len(multi_dgca.dgca) == 2
  119. for i in range(2):
  120. assert isinstance(multi_dgca.dgca[i],
  121. decagon_pytorch.convolve.DropoutGraphConvActivation)
  122. for i in range(2):
  123. multi_dgca.dgca[i].graph_conv.weight = weight[i]
  124. seq = torch.nn.Sequential(in_layer, d_layer)
  125. out_d_layer = seq(None)
  126. x = in_layer(None)
  127. out_multi_dgca = multi_dgca([ x[0], x[0] ])
  128. assert isinstance(out_d_layer, list)
  129. assert len(out_d_layer) == 1
  130. assert torch.all(out_d_layer[0] == out_multi_dgca)