If you would like an account, please email s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

test_convlayer.py 6.0KB

(line numbers 1–185)
  1. from icosagon.input import InputLayer, \
  2. OneHotInputLayer
  3. from icosagon.convlayer import DecagonLayer, \
  4. Convolutions
  5. from icosagon.data import Data
  6. import torch
  7. import pytest
  8. from icosagon.convolve import DropoutGraphConvActivation
  9. from decagon_pytorch.convolve import MultiDGCA
  10. import decagon_pytorch.convolve
  11. def _make_symmetric(x: torch.Tensor):
  12. x = (x + x.transpose(0, 1)) / 2
  13. return x
  14. def _symmetric_random(n_rows, n_columns):
  15. return _make_symmetric(torch.rand((n_rows, n_columns),
  16. dtype=torch.float32).round())
  17. def _some_data_with_interactions():
  18. d = Data()
  19. d.add_node_type('Gene', 1000)
  20. d.add_node_type('Drug', 100)
  21. fam = d.add_relation_family('Drug-Gene', 1, 0, True)
  22. fam.add_relation_type('Target', 1, 0,
  23. torch.rand((100, 1000), dtype=torch.float32).round())
  24. fam = d.add_relation_family('Gene-Gene', 0, 0, True)
  25. fam.add_relation_type('Interaction', 0, 0,
  26. _symmetric_random(1000, 1000))
  27. fam = d.add_relation_family('Drug-Drug', 1, 1, True)
  28. fam.add_relation_type('Side Effect: Nausea', 1, 1,
  29. _symmetric_random(100, 100))
  30. fam.add_relation_type('Side Effect: Infertility', 1, 1,
  31. _symmetric_random(100, 100))
  32. fam.add_relation_type('Side Effect: Death', 1, 1,
  33. _symmetric_random(100, 100))
  34. return d
  35. def test_decagon_layer_01():
  36. d = _some_data_with_interactions()
  37. in_layer = InputLayer(d)
  38. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  39. seq = torch.nn.Sequential(in_layer, d_layer)
  40. _ = seq(None) # dummy call
  41. def test_decagon_layer_02():
  42. d = _some_data_with_interactions()
  43. in_layer = OneHotInputLayer(d)
  44. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  45. seq = torch.nn.Sequential(in_layer, d_layer)
  46. _ = seq(None) # dummy call
  47. def test_decagon_layer_03():
  48. d = _some_data_with_interactions()
  49. in_layer = OneHotInputLayer(d)
  50. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  51. assert d_layer.input_dim == [ 1000, 100 ]
  52. assert d_layer.output_dim == [ 32, 32 ]
  53. assert d_layer.data == d
  54. assert d_layer.keep_prob == 1.
  55. assert d_layer.rel_activation(0.5) == 0.5
  56. x = torch.tensor([-1, 0, 0.5, 1])
  57. assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
  58. assert not d_layer.is_sparse
  59. assert len(d_layer.next_layer_repr) == 2
  60. for i in range(2):
  61. assert len(d_layer.next_layer_repr[i]) == 2
  62. assert isinstance(d_layer.next_layer_repr[i], list)
  63. assert isinstance(d_layer.next_layer_repr[i][0], Convolutions)
  64. assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int)
  65. assert isinstance(d_layer.next_layer_repr[i][0].convolutions, list)
  66. assert all([
  67. isinstance(dgca, DropoutGraphConvActivation) \
  68. for dgca in d_layer.next_layer_repr[i][0].convolutions
  69. ])
  70. assert all([
  71. dgca.output_dim == 32 \
  72. for dgca in d_layer.next_layer_repr[i][0].convolutions
  73. ])
  74. def test_decagon_layer_04():
  75. # check if it is equivalent to MultiDGCA, as it should be
  76. d = Data()
  77. d.add_node_type('Dummy', 100)
  78. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  79. fam.add_relation_type('Dummy Relation', 0, 0,
  80. _symmetric_random(100, 100).to_sparse())
  81. in_layer = OneHotInputLayer(d)
  82. multi_dgca = MultiDGCA([10], 32,
  83. [r.adjacency_matrix for r in fam.relation_types[0, 0]],
  84. keep_prob=1., activation=lambda x: x)
  85. d_layer = DecagonLayer(in_layer.output_dim, 32, d,
  86. keep_prob=1., rel_activation=lambda x: x,
  87. layer_activation=lambda x: x)
  88. assert isinstance(d_layer.next_layer_repr[0][0].convolutions[0],
  89. DropoutGraphConvActivation)
  90. weight = d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight
  91. assert isinstance(weight, torch.Tensor)
  92. assert len(multi_dgca.dgca) == 1
  93. assert isinstance(multi_dgca.dgca[0],
  94. decagon_pytorch.convolve.DropoutGraphConvActivation)
  95. multi_dgca.dgca[0].graph_conv.weight = weight
  96. seq = torch.nn.Sequential(in_layer, d_layer)
  97. out_d_layer = seq(None)
  98. out_multi_dgca = multi_dgca(in_layer(None))
  99. assert isinstance(out_d_layer, list)
  100. assert len(out_d_layer) == 1
  101. assert torch.all(out_d_layer[0] == out_multi_dgca)
  102. def test_decagon_layer_05():
  103. # check if it is equivalent to MultiDGCA, as it should be
  104. # this time for two relations, same edge type
  105. d = Data()
  106. d.add_node_type('Dummy', 100)
  107. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  108. fam.add_relation_type('Dummy Relation 1', 0, 0,
  109. _symmetric_random(100, 100).to_sparse())
  110. fam.add_relation_type('Dummy Relation 2', 0, 0,
  111. _symmetric_random(100, 100).to_sparse())
  112. in_layer = OneHotInputLayer(d)
  113. multi_dgca = MultiDGCA([100, 100], 32,
  114. [r.adjacency_matrix for r in fam.relation_types[0, 0]],
  115. keep_prob=1., activation=lambda x: x)
  116. d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
  117. keep_prob=1., rel_activation=lambda x: x,
  118. layer_activation=lambda x: x)
  119. assert all([
  120. isinstance(dgca, DropoutGraphConvActivation) \
  121. for dgca in d_layer.next_layer_repr[0][0].convolutions
  122. ])
  123. weight = [ dgca.graph_conv.weight \
  124. for dgca in d_layer.next_layer_repr[0][0].convolutions ]
  125. assert all([
  126. isinstance(w, torch.Tensor) \
  127. for w in weight
  128. ])
  129. assert len(multi_dgca.dgca) == 2
  130. for i in range(2):
  131. assert isinstance(multi_dgca.dgca[i],
  132. decagon_pytorch.convolve.DropoutGraphConvActivation)
  133. for i in range(2):
  134. multi_dgca.dgca[i].graph_conv.weight = weight[i]
  135. seq = torch.nn.Sequential(in_layer, d_layer)
  136. out_d_layer = seq(None)
  137. x = in_layer(None)
  138. out_multi_dgca = multi_dgca([ x[0], x[0] ])
  139. assert isinstance(out_d_layer, list)
  140. assert len(out_d_layer) == 1
  141. assert torch.all(out_d_layer[0] == out_multi_dgca)