IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

186 Zeilen
6.0KB

  1. from icosagon.input import InputLayer, \
  2. OneHotInputLayer
  3. from icosagon.convlayer import DecagonLayer, \
  4. Convolutions
  5. from icosagon.data import Data
  6. import torch
  7. import pytest
  8. from icosagon.convolve import DropoutGraphConvActivation
  9. from decagon_pytorch.convolve import MultiDGCA
  10. import decagon_pytorch.convolve
  11. def _make_symmetric(x: torch.Tensor):
  12. x = (x + x.transpose(0, 1)) / 2
  13. return x
  14. def _symmetric_random(n_rows, n_columns):
  15. return _make_symmetric(torch.rand((n_rows, n_columns),
  16. dtype=torch.float32).round())
  17. def _some_data_with_interactions():
  18. d = Data()
  19. d.add_node_type('Gene', 1000)
  20. d.add_node_type('Drug', 100)
  21. fam = d.add_relation_family('Drug-Gene', 1, 0, True)
  22. fam.add_relation_type('Target', 1, 0,
  23. torch.rand((100, 1000), dtype=torch.float32).round())
  24. fam = d.add_relation_family('Gene-Gene', 0, 0, True)
  25. fam.add_relation_type('Interaction', 0, 0,
  26. _symmetric_random(1000, 1000))
  27. fam = d.add_relation_family('Drug-Drug', 1, 1, True)
  28. fam.add_relation_type('Side Effect: Nausea', 1, 1,
  29. _symmetric_random(100, 100))
  30. fam.add_relation_type('Side Effect: Infertility', 1, 1,
  31. _symmetric_random(100, 100))
  32. fam.add_relation_type('Side Effect: Death', 1, 1,
  33. _symmetric_random(100, 100))
  34. return d
  35. def test_decagon_layer_01():
  36. d = _some_data_with_interactions()
  37. in_layer = InputLayer(d)
  38. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  39. seq = torch.nn.Sequential(in_layer, d_layer)
  40. _ = seq(None) # dummy call
  41. def test_decagon_layer_02():
  42. d = _some_data_with_interactions()
  43. in_layer = OneHotInputLayer(d)
  44. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  45. seq = torch.nn.Sequential(in_layer, d_layer)
  46. _ = seq(None) # dummy call
  47. def test_decagon_layer_03():
  48. d = _some_data_with_interactions()
  49. in_layer = OneHotInputLayer(d)
  50. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  51. assert d_layer.input_dim == [ 1000, 100 ]
  52. assert d_layer.output_dim == [ 32, 32 ]
  53. assert d_layer.data == d
  54. assert d_layer.keep_prob == 1.
  55. assert d_layer.rel_activation(0.5) == 0.5
  56. x = torch.tensor([-1, 0, 0.5, 1])
  57. assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
  58. assert not d_layer.is_sparse
  59. assert len(d_layer.next_layer_repr) == 2
  60. for i in range(2):
  61. assert len(d_layer.next_layer_repr[i]) == 2
  62. assert isinstance(d_layer.next_layer_repr[i], list)
  63. assert isinstance(d_layer.next_layer_repr[i][0], Convolutions)
  64. assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int)
  65. assert isinstance(d_layer.next_layer_repr[i][0].convolutions, list)
  66. assert all([
  67. isinstance(dgca, DropoutGraphConvActivation) \
  68. for dgca in d_layer.next_layer_repr[i][0].convolutions
  69. ])
  70. assert all([
  71. dgca.output_dim == 32 \
  72. for dgca in d_layer.next_layer_repr[i][0].convolutions
  73. ])
  74. def test_decagon_layer_04():
  75. # check if it is equivalent to MultiDGCA, as it should be
  76. d = Data()
  77. d.add_node_type('Dummy', 100)
  78. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  79. fam.add_relation_type('Dummy Relation', 0, 0,
  80. _symmetric_random(100, 100).to_sparse())
  81. in_layer = OneHotInputLayer(d)
  82. multi_dgca = MultiDGCA([10], 32,
  83. [r.adjacency_matrix for r in fam.relation_types[0, 0]],
  84. keep_prob=1., activation=lambda x: x)
  85. d_layer = DecagonLayer(in_layer.output_dim, 32, d,
  86. keep_prob=1., rel_activation=lambda x: x,
  87. layer_activation=lambda x: x)
  88. assert isinstance(d_layer.next_layer_repr[0][0].convolutions[0],
  89. DropoutGraphConvActivation)
  90. weight = d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight
  91. assert isinstance(weight, torch.Tensor)
  92. assert len(multi_dgca.dgca) == 1
  93. assert isinstance(multi_dgca.dgca[0],
  94. decagon_pytorch.convolve.DropoutGraphConvActivation)
  95. multi_dgca.dgca[0].graph_conv.weight = weight
  96. seq = torch.nn.Sequential(in_layer, d_layer)
  97. out_d_layer = seq(None)
  98. out_multi_dgca = multi_dgca(in_layer(None))
  99. assert isinstance(out_d_layer, list)
  100. assert len(out_d_layer) == 1
  101. assert torch.all(out_d_layer[0] == out_multi_dgca)
  102. def test_decagon_layer_05():
  103. # check if it is equivalent to MultiDGCA, as it should be
  104. # this time for two relations, same edge type
  105. d = Data()
  106. d.add_node_type('Dummy', 100)
  107. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  108. fam.add_relation_type('Dummy Relation 1', 0, 0,
  109. _symmetric_random(100, 100).to_sparse())
  110. fam.add_relation_type('Dummy Relation 2', 0, 0,
  111. _symmetric_random(100, 100).to_sparse())
  112. in_layer = OneHotInputLayer(d)
  113. multi_dgca = MultiDGCA([100, 100], 32,
  114. [r.adjacency_matrix for r in fam.relation_types[0, 0]],
  115. keep_prob=1., activation=lambda x: x)
  116. d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
  117. keep_prob=1., rel_activation=lambda x: x,
  118. layer_activation=lambda x: x)
  119. assert all([
  120. isinstance(dgca, DropoutGraphConvActivation) \
  121. for dgca in d_layer.next_layer_repr[0][0].convolutions
  122. ])
  123. weight = [ dgca.graph_conv.weight \
  124. for dgca in d_layer.next_layer_repr[0][0].convolutions ]
  125. assert all([
  126. isinstance(w, torch.Tensor) \
  127. for w in weight
  128. ])
  129. assert len(multi_dgca.dgca) == 2
  130. for i in range(2):
  131. assert isinstance(multi_dgca.dgca[i],
  132. decagon_pytorch.convolve.DropoutGraphConvActivation)
  133. for i in range(2):
  134. multi_dgca.dgca[i].graph_conv.weight = weight[i]
  135. seq = torch.nn.Sequential(in_layer, d_layer)
  136. out_d_layer = seq(None)
  137. x = in_layer(None)
  138. out_multi_dgca = multi_dgca([ x[0], x[0] ])
  139. assert isinstance(out_d_layer, list)
  140. assert len(out_d_layer) == 1
  141. assert torch.all(out_d_layer[0] == out_multi_dgca)