IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

test_convlayer.py 9.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. from icosagon.input import InputLayer, \
  2. OneHotInputLayer
  3. from icosagon.convlayer import DecagonLayer, \
  4. Convolutions
  5. from icosagon.data import Data
  6. import torch
  7. import pytest
  8. from icosagon.convolve import DropoutGraphConvActivation
  9. from decagon_pytorch.convolve import MultiDGCA
  10. import decagon_pytorch.convolve
  11. def _make_symmetric(x: torch.Tensor):
  12. x = (x + x.transpose(0, 1)) / 2
  13. return x
  14. def _symmetric_random(n_rows, n_columns):
  15. return _make_symmetric(torch.rand((n_rows, n_columns),
  16. dtype=torch.float32).round())
  17. def _some_data_with_interactions():
  18. d = Data()
  19. d.add_node_type('Gene', 1000)
  20. d.add_node_type('Drug', 100)
  21. fam = d.add_relation_family('Drug-Gene', 1, 0, True)
  22. fam.add_relation_type('Target',
  23. torch.rand((100, 1000), dtype=torch.float32).round())
  24. fam = d.add_relation_family('Gene-Gene', 0, 0, True)
  25. fam.add_relation_type('Interaction',
  26. _symmetric_random(1000, 1000))
  27. fam = d.add_relation_family('Drug-Drug', 1, 1, True)
  28. fam.add_relation_type('Side Effect: Nausea',
  29. _symmetric_random(100, 100))
  30. fam.add_relation_type('Side Effect: Infertility',
  31. _symmetric_random(100, 100))
  32. fam.add_relation_type('Side Effect: Death',
  33. _symmetric_random(100, 100))
  34. return d
  35. def test_decagon_layer_01():
  36. d = _some_data_with_interactions()
  37. in_layer = InputLayer(d)
  38. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  39. seq = torch.nn.Sequential(in_layer, d_layer)
  40. _ = seq(None) # dummy call
  41. def test_decagon_layer_02():
  42. d = _some_data_with_interactions()
  43. in_layer = OneHotInputLayer(d)
  44. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  45. seq = torch.nn.Sequential(in_layer, d_layer)
  46. _ = seq(None) # dummy call
  47. def test_decagon_layer_03():
  48. d = _some_data_with_interactions()
  49. in_layer = OneHotInputLayer(d)
  50. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  51. assert d_layer.input_dim == [ 1000, 100 ]
  52. assert d_layer.output_dim == [ 32, 32 ]
  53. assert d_layer.data == d
  54. assert d_layer.keep_prob == 1.
  55. assert d_layer.rel_activation(0.5) == 0.5
  56. x = torch.tensor([-1, 0, 0.5, 1])
  57. assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
  58. assert not d_layer.is_sparse
  59. assert len(d_layer.next_layer_repr) == 2
  60. for i in range(2):
  61. assert len(d_layer.next_layer_repr[i]) == 2
  62. assert isinstance(d_layer.next_layer_repr[i], torch.nn.ModuleList)
  63. assert isinstance(d_layer.next_layer_repr[i][0], Convolutions)
  64. assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int)
  65. assert isinstance(d_layer.next_layer_repr[i][0].convolutions, torch.nn.ModuleList)
  66. assert all([
  67. isinstance(dgca, DropoutGraphConvActivation) \
  68. for dgca in d_layer.next_layer_repr[i][0].convolutions
  69. ])
  70. assert all([
  71. dgca.output_dim == 32 \
  72. for dgca in d_layer.next_layer_repr[i][0].convolutions
  73. ])
  74. def test_decagon_layer_04():
  75. # check if it is equivalent to MultiDGCA, as it should be
  76. d = Data()
  77. d.add_node_type('Dummy', 100)
  78. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  79. fam.add_relation_type('Dummy Relation',
  80. _symmetric_random(100, 100).to_sparse())
  81. in_layer = OneHotInputLayer(d)
  82. multi_dgca = MultiDGCA([10], 32,
  83. [r.adjacency_matrix for r in fam.relation_types],
  84. keep_prob=1., activation=lambda x: x)
  85. d_layer = DecagonLayer(in_layer.output_dim, 32, d,
  86. keep_prob=1., rel_activation=lambda x: x,
  87. layer_activation=lambda x: x)
  88. assert isinstance(d_layer.next_layer_repr[0][0].convolutions[0],
  89. DropoutGraphConvActivation)
  90. weight = d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight
  91. assert isinstance(weight, torch.Tensor)
  92. assert len(multi_dgca.dgca) == 1
  93. assert isinstance(multi_dgca.dgca[0],
  94. decagon_pytorch.convolve.DropoutGraphConvActivation)
  95. multi_dgca.dgca[0].graph_conv.weight = weight
  96. seq = torch.nn.Sequential(in_layer, d_layer)
  97. out_d_layer = seq(None)
  98. out_multi_dgca = multi_dgca(list(in_layer(None)))
  99. assert isinstance(out_d_layer, list)
  100. assert len(out_d_layer) == 1
  101. assert torch.all(out_d_layer[0] == out_multi_dgca)
  102. def test_decagon_layer_05():
  103. # check if it is equivalent to MultiDGCA, as it should be
  104. # this time for two relations, same edge type
  105. d = Data()
  106. d.add_node_type('Dummy', 100)
  107. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  108. fam.add_relation_type('Dummy Relation 1',
  109. _symmetric_random(100, 100).to_sparse())
  110. fam.add_relation_type('Dummy Relation 2',
  111. _symmetric_random(100, 100).to_sparse())
  112. in_layer = OneHotInputLayer(d)
  113. multi_dgca = MultiDGCA([100, 100], 32,
  114. [r.adjacency_matrix for r in fam.relation_types],
  115. keep_prob=1., activation=lambda x: x)
  116. d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
  117. keep_prob=1., rel_activation=lambda x: x,
  118. layer_activation=lambda x: x)
  119. assert all([
  120. isinstance(dgca, DropoutGraphConvActivation) \
  121. for dgca in d_layer.next_layer_repr[0][0].convolutions
  122. ])
  123. weight = [ dgca.graph_conv.weight \
  124. for dgca in d_layer.next_layer_repr[0][0].convolutions ]
  125. assert all([
  126. isinstance(w, torch.Tensor) \
  127. for w in weight
  128. ])
  129. assert len(multi_dgca.dgca) == 2
  130. for i in range(2):
  131. assert isinstance(multi_dgca.dgca[i],
  132. decagon_pytorch.convolve.DropoutGraphConvActivation)
  133. for i in range(2):
  134. multi_dgca.dgca[i].graph_conv.weight = weight[i]
  135. seq = torch.nn.Sequential(in_layer, d_layer)
  136. out_d_layer = seq(None)
  137. x = in_layer(None)
  138. out_multi_dgca = multi_dgca([ x[0], x[0] ])
  139. assert isinstance(out_d_layer, list)
  140. assert len(out_d_layer) == 1
  141. assert torch.all(out_d_layer[0] == out_multi_dgca)
  142. class Dummy1(torch.nn.Module):
  143. def __init__(self, **kwargs):
  144. super().__init__(**kwargs)
  145. self.whatever = torch.nn.Parameter(torch.rand((10, 10)))
  146. class Dummy2(torch.nn.Module):
  147. def __init__(self, **kwargs):
  148. super().__init__(**kwargs)
  149. self.dummy_1 = Dummy1()
  150. class Dummy3(torch.nn.Module):
  151. def __init__(self, **kwargs):
  152. super().__init__(**kwargs)
  153. self.dummy_1 = [ Dummy1() ]
  154. class Dummy4(torch.nn.Module):
  155. def __init__(self, **kwargs):
  156. super().__init__(**kwargs)
  157. self.dummy_1 = torch.nn.ModuleList([ Dummy1() ])
  158. class Dummy5(torch.nn.Module):
  159. def __init__(self, **kwargs):
  160. super().__init__(**kwargs)
  161. self.dummy_1 = [ torch.nn.ModuleList([ Dummy1() ]) ]
  162. class Dummy6(torch.nn.Module):
  163. def __init__(self, **kwargs):
  164. super().__init__(**kwargs)
  165. self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList([ Dummy1() ]) ])
  166. class Dummy7(torch.nn.Module):
  167. def __init__(self, **kwargs):
  168. super().__init__(**kwargs)
  169. self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList() ])
  170. self.dummy_1[0].append(Dummy1())
  171. def test_module_nesting_01():
  172. if torch.cuda.device_count() == 0:
  173. pytest.skip('No CUDA support on this host')
  174. device = torch.device('cuda:0')
  175. dummy_2 = Dummy2()
  176. dummy_2 = dummy_2.to(device)
  177. assert dummy_2.dummy_1.whatever.device == device
  178. def test_module_nesting_02():
  179. if torch.cuda.device_count() == 0:
  180. pytest.skip('No CUDA support on this host')
  181. device = torch.device('cuda:0')
  182. dummy_3 = Dummy3()
  183. dummy_3 = dummy_3.to(device)
  184. assert dummy_3.dummy_1[0].whatever.device != device
  185. def test_module_nesting_03():
  186. if torch.cuda.device_count() == 0:
  187. pytest.skip('No CUDA support on this host')
  188. device = torch.device('cuda:0')
  189. dummy_4 = Dummy4()
  190. dummy_4 = dummy_4.to(device)
  191. assert dummy_4.dummy_1[0].whatever.device == device
  192. def test_module_nesting_04():
  193. if torch.cuda.device_count() == 0:
  194. pytest.skip('No CUDA support on this host')
  195. device = torch.device('cuda:0')
  196. dummy_5 = Dummy5()
  197. dummy_5 = dummy_5.to(device)
  198. assert dummy_5.dummy_1[0][0].whatever.device != device
  199. def test_module_nesting_05():
  200. if torch.cuda.device_count() == 0:
  201. pytest.skip('No CUDA support on this host')
  202. device = torch.device('cuda:0')
  203. dummy_6 = Dummy6()
  204. dummy_6 = dummy_6.to(device)
  205. assert dummy_6.dummy_1[0][0].whatever.device == device
  206. def test_module_nesting_06():
  207. if torch.cuda.device_count() == 0:
  208. pytest.skip('No CUDA support on this host')
  209. device = torch.device('cuda:0')
  210. dummy_7 = Dummy7()
  211. dummy_7 = dummy_7.to(device)
  212. assert dummy_7.dummy_1[0][0].whatever.device == device
  213. def test_parameter_count_01():
  214. d = Data()
  215. d.add_node_type('Dummy', 100)
  216. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
  217. fam.add_relation_type('Dummy Relation 1',
  218. _symmetric_random(100, 100).to_sparse())
  219. fam.add_relation_type('Dummy Relation 2',
  220. _symmetric_random(100, 100).to_sparse())
  221. in_layer = OneHotInputLayer(d)
  222. assert len(list(in_layer.parameters())) == 1
  223. d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
  224. keep_prob=1., rel_activation=lambda x: x,
  225. layer_activation=lambda x: x)
  226. assert len(list(d_layer.parameters())) == 2