If you would like an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

168 lines
5.7KB

  1. from icosagon.input import InputLayer, \
  2. OneHotInputLayer
  3. from icosagon.convlayer import DecagonLayer, \
  4. Convolutions
  5. from icosagon.data import Data
  6. import torch
  7. import pytest
  8. from icosagon.convolve import DropoutGraphConvActivation
  9. from decagon_pytorch.convolve import MultiDGCA
  10. import decagon_pytorch.convolve
  11. def _some_data_with_interactions():
  12. d = Data()
  13. d.add_node_type('Gene', 1000)
  14. d.add_node_type('Drug', 100)
  15. d.add_relation_type('Target', 1, 0,
  16. torch.rand((100, 1000), dtype=torch.float32).round())
  17. d.add_relation_type('Interaction', 0, 0,
  18. torch.rand((1000, 1000), dtype=torch.float32).round())
  19. d.add_relation_type('Side Effect: Nausea', 1, 1,
  20. torch.rand((100, 100), dtype=torch.float32).round())
  21. d.add_relation_type('Side Effect: Infertility', 1, 1,
  22. torch.rand((100, 100), dtype=torch.float32).round())
  23. d.add_relation_type('Side Effect: Death', 1, 1,
  24. torch.rand((100, 100), dtype=torch.float32).round())
  25. return d
  26. def test_decagon_layer_01():
  27. d = _some_data_with_interactions()
  28. in_layer = InputLayer(d)
  29. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  30. seq = torch.nn.Sequential(in_layer, d_layer)
  31. _ = seq(None) # dummy call
  32. def test_decagon_layer_02():
  33. d = _some_data_with_interactions()
  34. in_layer = OneHotInputLayer(d)
  35. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  36. seq = torch.nn.Sequential(in_layer, d_layer)
  37. _ = seq(None) # dummy call
  38. def test_decagon_layer_03():
  39. d = _some_data_with_interactions()
  40. in_layer = OneHotInputLayer(d)
  41. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  42. assert d_layer.input_dim == [ 1000, 100 ]
  43. assert d_layer.output_dim == [ 32, 32 ]
  44. assert d_layer.data == d
  45. assert d_layer.keep_prob == 1.
  46. assert d_layer.rel_activation(0.5) == 0.5
  47. x = torch.tensor([-1, 0, 0.5, 1])
  48. assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
  49. assert not d_layer.is_sparse
  50. assert len(d_layer.next_layer_repr) == 2
  51. for i in range(2):
  52. assert len(d_layer.next_layer_repr[i]) == 2
  53. assert isinstance(d_layer.next_layer_repr[i], list)
  54. assert isinstance(d_layer.next_layer_repr[i][0], Convolutions)
  55. assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int)
  56. assert isinstance(d_layer.next_layer_repr[i][0].convolutions, list)
  57. assert all([
  58. isinstance(dgca, DropoutGraphConvActivation) \
  59. for dgca in d_layer.next_layer_repr[i][0].convolutions
  60. ])
  61. assert all([
  62. dgca.output_dim == 32 \
  63. for dgca in d_layer.next_layer_repr[i][0].convolutions
  64. ])
  65. def test_decagon_layer_04():
  66. # check if it is equivalent to MultiDGCA, as it should be
  67. d = Data()
  68. d.add_node_type('Dummy', 100)
  69. d.add_relation_type('Dummy Relation', 0, 0,
  70. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  71. in_layer = OneHotInputLayer(d)
  72. multi_dgca = MultiDGCA([10], 32,
  73. [r.adjacency_matrix for r in d.relation_types[0, 0]],
  74. keep_prob=1., activation=lambda x: x)
  75. d_layer = DecagonLayer(in_layer.output_dim, 32, d,
  76. keep_prob=1., rel_activation=lambda x: x,
  77. layer_activation=lambda x: x)
  78. assert isinstance(d_layer.next_layer_repr[0][0].convolutions[0],
  79. DropoutGraphConvActivation)
  80. weight = d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight
  81. assert isinstance(weight, torch.Tensor)
  82. assert len(multi_dgca.dgca) == 1
  83. assert isinstance(multi_dgca.dgca[0],
  84. decagon_pytorch.convolve.DropoutGraphConvActivation)
  85. multi_dgca.dgca[0].graph_conv.weight = weight
  86. seq = torch.nn.Sequential(in_layer, d_layer)
  87. out_d_layer = seq(None)
  88. out_multi_dgca = multi_dgca(in_layer(None))
  89. assert isinstance(out_d_layer, list)
  90. assert len(out_d_layer) == 1
  91. assert torch.all(out_d_layer[0] == out_multi_dgca)
  92. def test_decagon_layer_05():
  93. # check if it is equivalent to MultiDGCA, as it should be
  94. # this time for two relations, same edge type
  95. d = Data()
  96. d.add_node_type('Dummy', 100)
  97. d.add_relation_type('Dummy Relation 1', 0, 0,
  98. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  99. d.add_relation_type('Dummy Relation 2', 0, 0,
  100. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  101. in_layer = OneHotInputLayer(d)
  102. multi_dgca = MultiDGCA([100, 100], 32,
  103. [r.adjacency_matrix for r in d.relation_types[0, 0]],
  104. keep_prob=1., activation=lambda x: x)
  105. d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
  106. keep_prob=1., rel_activation=lambda x: x,
  107. layer_activation=lambda x: x)
  108. assert all([
  109. isinstance(dgca, DropoutGraphConvActivation) \
  110. for dgca in d_layer.next_layer_repr[0][0].convolutions
  111. ])
  112. weight = [ dgca.graph_conv.weight \
  113. for dgca in d_layer.next_layer_repr[0][0].convolutions ]
  114. assert all([
  115. isinstance(w, torch.Tensor) \
  116. for w in weight
  117. ])
  118. assert len(multi_dgca.dgca) == 2
  119. for i in range(2):
  120. assert isinstance(multi_dgca.dgca[i],
  121. decagon_pytorch.convolve.DropoutGraphConvActivation)
  122. for i in range(2):
  123. multi_dgca.dgca[i].graph_conv.weight = weight[i]
  124. seq = torch.nn.Sequential(in_layer, d_layer)
  125. out_d_layer = seq(None)
  126. x = in_layer(None)
  127. out_multi_dgca = multi_dgca([ x[0], x[0] ])
  128. assert isinstance(out_d_layer, list)
  129. assert len(out_d_layer) == 1
  130. assert torch.all(out_d_layer[0] == out_multi_dgca)