IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

test_convlayer.py 9.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. from icosagon.input import InputLayer, \
  2. OneHotInputLayer
  3. from icosagon.convlayer import DecagonLayer, \
  4. Convolutions
  5. from icosagon.data import Data
  6. import torch
  7. import pytest
  8. from icosagon.convolve import DropoutGraphConvActivation
  9. from decagon_pytorch.convolve import MultiDGCA
  10. import decagon_pytorch.convolve
  11. def _make_symmetric(x: torch.Tensor):
  12. x = (x + x.transpose(0, 1)) / 2
  13. return x
  14. def _symmetric_random(n_rows, n_columns):
  15. return _make_symmetric(torch.rand((n_rows, n_columns),
  16. dtype=torch.float32).round())
  17. def _some_data_with_interactions():
  18. d = Data()
  19. d.add_node_type('Gene', 1000)
  20. d.add_node_type('Drug', 100)
  21. fam = d.add_relation_family('Drug-Gene', 1, 0, True)
  22. fam.add_relation_type('Target',
  23. torch.rand((100, 1000), dtype=torch.float32).round())
  24. fam = d.add_relation_family('Gene-Gene', 0, 0, True)
  25. fam.add_relation_type('Interaction',
  26. _symmetric_random(1000, 1000))
  27. fam = d.add_relation_family('Drug-Drug', 1, 1, True)
  28. fam.add_relation_type('Side Effect: Nausea',
  29. _symmetric_random(100, 100))
  30. fam.add_relation_type('Side Effect: Infertility',
  31. _symmetric_random(100, 100))
  32. fam.add_relation_type('Side Effect: Death',
  33. _symmetric_random(100, 100))
  34. return d
  35. def test_decagon_layer_01():
  36. d = _some_data_with_interactions()
  37. in_layer = InputLayer(d)
  38. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  39. seq = torch.nn.Sequential(in_layer, d_layer)
  40. _ = seq(None) # dummy call
  41. def test_decagon_layer_02():
  42. d = _some_data_with_interactions()
  43. in_layer = OneHotInputLayer(d)
  44. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  45. seq = torch.nn.Sequential(in_layer, d_layer)
  46. _ = seq(None) # dummy call
  47. def test_decagon_layer_03():
  48. d = _some_data_with_interactions()
  49. in_layer = OneHotInputLayer(d)
  50. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  51. assert d_layer.input_dim == [ 1000, 100 ]
  52. assert d_layer.output_dim == [ 32, 32 ]
  53. assert d_layer.data == d
  54. assert d_layer.keep_prob == 1.
  55. assert d_layer.rel_activation(0.5) == 0.5
  56. x = torch.tensor([-1, 0, 0.5, 1])
  57. assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
  58. assert not d_layer.is_sparse
  59. assert len(d_layer.next_layer_repr) == 2
  60. for i in range(2):
  61. assert len(d_layer.next_layer_repr[i]) == 2
  62. assert isinstance(d_layer.next_layer_repr[i], torch.nn.ModuleList)
  63. assert isinstance(d_layer.next_layer_repr[i][0], Convolutions)
  64. assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int)
  65. assert isinstance(d_layer.next_layer_repr[i][0].convolutions, torch.nn.ModuleList)
  66. assert all([
  67. isinstance(dgca, DropoutGraphConvActivation) \
  68. for dgca in d_layer.next_layer_repr[i][0].convolutions
  69. ])
  70. assert all([
  71. dgca.output_dim == 32 \
  72. for dgca in d_layer.next_layer_repr[i][0].convolutions
  73. ])
  74. def test_decagon_layer_04():
  75. # check if it is equivalent to MultiDGCA, as it should be
  76. d = Data()
  77. d.add_node_type('Dummy', 100)
  78. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  79. fam.add_relation_type('Dummy Relation',
  80. _symmetric_random(100, 100).to_sparse())
  81. in_layer = OneHotInputLayer(d)
  82. multi_dgca = MultiDGCA([10], 32,
  83. [r.adjacency_matrix for r in fam.relation_types],
  84. keep_prob=1., activation=lambda x: x)
  85. d_layer = DecagonLayer(in_layer.output_dim, 32, d,
  86. keep_prob=1., rel_activation=lambda x: x,
  87. layer_activation=lambda x: x)
  88. assert isinstance(d_layer.next_layer_repr[0][0].convolutions[0],
  89. DropoutGraphConvActivation)
  90. weight = d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight
  91. assert isinstance(weight, torch.Tensor)
  92. assert len(multi_dgca.dgca) == 1
  93. assert isinstance(multi_dgca.dgca[0],
  94. decagon_pytorch.convolve.DropoutGraphConvActivation)
  95. multi_dgca.dgca[0].graph_conv.weight = weight
  96. seq = torch.nn.Sequential(in_layer, d_layer)
  97. out_d_layer = seq(None)
  98. out_multi_dgca = multi_dgca(in_layer(None))
  99. assert isinstance(out_d_layer, list)
  100. assert len(out_d_layer) == 1
  101. assert torch.all(out_d_layer[0] == out_multi_dgca)
  102. def test_decagon_layer_05():
  103. # check if it is equivalent to MultiDGCA, as it should be
  104. # this time for two relations, same edge type
  105. d = Data()
  106. d.add_node_type('Dummy', 100)
  107. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  108. fam.add_relation_type('Dummy Relation 1',
  109. _symmetric_random(100, 100).to_sparse())
  110. fam.add_relation_type('Dummy Relation 2',
  111. _symmetric_random(100, 100).to_sparse())
  112. in_layer = OneHotInputLayer(d)
  113. multi_dgca = MultiDGCA([100, 100], 32,
  114. [r.adjacency_matrix for r in fam.relation_types],
  115. keep_prob=1., activation=lambda x: x)
  116. d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
  117. keep_prob=1., rel_activation=lambda x: x,
  118. layer_activation=lambda x: x)
  119. assert all([
  120. isinstance(dgca, DropoutGraphConvActivation) \
  121. for dgca in d_layer.next_layer_repr[0][0].convolutions
  122. ])
  123. weight = [ dgca.graph_conv.weight \
  124. for dgca in d_layer.next_layer_repr[0][0].convolutions ]
  125. assert all([
  126. isinstance(w, torch.Tensor) \
  127. for w in weight
  128. ])
  129. assert len(multi_dgca.dgca) == 2
  130. for i in range(2):
  131. assert isinstance(multi_dgca.dgca[i],
  132. decagon_pytorch.convolve.DropoutGraphConvActivation)
  133. for i in range(2):
  134. multi_dgca.dgca[i].graph_conv.weight = weight[i]
  135. seq = torch.nn.Sequential(in_layer, d_layer)
  136. out_d_layer = seq(None)
  137. x = in_layer(None)
  138. out_multi_dgca = multi_dgca([ x[0], x[0] ])
  139. assert isinstance(out_d_layer, list)
  140. assert len(out_d_layer) == 1
  141. assert torch.all(out_d_layer[0] == out_multi_dgca)
  142. class Dummy1(torch.nn.Module):
  143. def __init__(self, **kwargs):
  144. super().__init__(**kwargs)
  145. self.whatever = torch.nn.Parameter(torch.rand((10, 10)))
  146. class Dummy2(torch.nn.Module):
  147. def __init__(self, **kwargs):
  148. super().__init__(**kwargs)
  149. self.dummy_1 = Dummy1()
  150. class Dummy3(torch.nn.Module):
  151. def __init__(self, **kwargs):
  152. super().__init__(**kwargs)
  153. self.dummy_1 = [ Dummy1() ]
  154. class Dummy4(torch.nn.Module):
  155. def __init__(self, **kwargs):
  156. super().__init__(**kwargs)
  157. self.dummy_1 = torch.nn.ModuleList([ Dummy1() ])
  158. class Dummy5(torch.nn.Module):
  159. def __init__(self, **kwargs):
  160. super().__init__(**kwargs)
  161. self.dummy_1 = [ torch.nn.ModuleList([ Dummy1() ]) ]
  162. class Dummy6(torch.nn.Module):
  163. def __init__(self, **kwargs):
  164. super().__init__(**kwargs)
  165. self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList([ Dummy1() ]) ])
  166. class Dummy7(torch.nn.Module):
  167. def __init__(self, **kwargs):
  168. super().__init__(**kwargs)
  169. self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList() ])
  170. self.dummy_1[0].append(Dummy1())
  171. def test_module_nesting_01():
  172. if torch.cuda.device_count() == 0:
  173. pytest.skip('No CUDA support on this host')
  174. device = torch.device('cuda:0')
  175. dummy_2 = Dummy2()
  176. dummy_2 = dummy_2.to(device)
  177. assert dummy_2.dummy_1.whatever.device == device
  178. def test_module_nesting_02():
  179. if torch.cuda.device_count() == 0:
  180. pytest.skip('No CUDA support on this host')
  181. device = torch.device('cuda:0')
  182. dummy_3 = Dummy3()
  183. dummy_3 = dummy_3.to(device)
  184. assert dummy_3.dummy_1[0].whatever.device != device
  185. def test_module_nesting_03():
  186. if torch.cuda.device_count() == 0:
  187. pytest.skip('No CUDA support on this host')
  188. device = torch.device('cuda:0')
  189. dummy_4 = Dummy4()
  190. dummy_4 = dummy_4.to(device)
  191. assert dummy_4.dummy_1[0].whatever.device == device
  192. def test_module_nesting_04():
  193. if torch.cuda.device_count() == 0:
  194. pytest.skip('No CUDA support on this host')
  195. device = torch.device('cuda:0')
  196. dummy_5 = Dummy5()
  197. dummy_5 = dummy_5.to(device)
  198. assert dummy_5.dummy_1[0][0].whatever.device != device
  199. def test_module_nesting_05():
  200. if torch.cuda.device_count() == 0:
  201. pytest.skip('No CUDA support on this host')
  202. device = torch.device('cuda:0')
  203. dummy_6 = Dummy6()
  204. dummy_6 = dummy_6.to(device)
  205. assert dummy_6.dummy_1[0][0].whatever.device == device
  206. def test_module_nesting_06():
  207. if torch.cuda.device_count() == 0:
  208. pytest.skip('No CUDA support on this host')
  209. device = torch.device('cuda:0')
  210. dummy_7 = Dummy7()
  211. dummy_7 = dummy_7.to(device)
  212. assert dummy_7.dummy_1[0][0].whatever.device == device
  213. def test_parameter_count_01():
  214. d = Data()
  215. d.add_node_type('Dummy', 100)
  216. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  217. fam.add_relation_type('Dummy Relation 1',
  218. _symmetric_random(100, 100).to_sparse())
  219. fam.add_relation_type('Dummy Relation 2',
  220. _symmetric_random(100, 100).to_sparse())
  221. in_layer = OneHotInputLayer(d)
  222. assert len(list(in_layer.parameters())) == 1
  223. d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
  224. keep_prob=1., rel_activation=lambda x: x,
  225. layer_activation=lambda x: x)
  226. assert len(list(d_layer.parameters())) == 2