IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

231 lines
7.1KB

  1. from icosagon.input import InputLayer, \
  2. OneHotInputLayer
  3. from icosagon.convlayer import DecagonLayer, \
  4. Convolutions
  5. from icosagon.data import Data
  6. import torch
  7. import pytest
  8. from icosagon.convolve import DropoutGraphConvActivation
  9. from decagon_pytorch.convolve import MultiDGCA
  10. import decagon_pytorch.convolve
  def _make_symmetric(x: torch.Tensor) -> torch.Tensor:
      """Return the symmetric part of *x*, i.e. ``(x + x^T) / 2``.

      The transpose swaps the first two dimensions, so the helper is only
      meaningful when those two dimensions have the same size.
      """
      transposed = x.transpose(0, 1)
      return (x + transposed) / 2
  def _symmetric_random(n_rows, n_columns):
      """Build a random symmetric float32 matrix for use as test adjacency data.

      Draws an ``(n_rows, n_columns)`` tensor uniformly from [0, 1), rounds
      every entry to 0 or 1, then symmetrises it with ``_make_symmetric``.
      Because symmetrising averages the matrix with its transpose, entries of
      the result may be 0, 0.5 or 1. Intended for square shapes only.
      """
      binary = torch.rand((n_rows, n_columns), dtype=torch.float32)
      binary = binary.round()
      return _make_symmetric(binary)
  def _some_data_with_interactions():
      """Build a two-node-type ``Data`` fixture populated with random relations.

      Node types registered, in order: 'Gene' (1000) and 'Drug' (100).
      Relation families:

      * 'Drug-Gene' (arguments 1, 0, True): relation 'Target', a random
        0/1 float32 matrix of shape (100, 1000);
      * 'Gene-Gene' (arguments 0, 0, True): relation 'Interaction', a
        symmetric random (1000, 1000) matrix;
      * 'Drug-Drug' (arguments 1, 1, True): three side-effect relations, each
        a symmetric random (100, 100) matrix.
      """
      data = Data()
      data.add_node_type('Gene', 1000)
      data.add_node_type('Drug', 100)

      drug_gene = data.add_relation_family('Drug-Gene', 1, 0, True)
      targets = torch.rand((100, 1000), dtype=torch.float32).round()
      drug_gene.add_relation_type('Target', targets)

      gene_gene = data.add_relation_family('Gene-Gene', 0, 0, True)
      gene_gene.add_relation_type('Interaction', _symmetric_random(1000, 1000))

      drug_drug = data.add_relation_family('Drug-Drug', 1, 1, True)
      side_effects = (
          'Side Effect: Nausea',
          'Side Effect: Infertility',
          'Side Effect: Death',
      )
      for relation_name in side_effects:
          drug_drug.add_relation_type(relation_name,
                                      _symmetric_random(100, 100))

      return data
  def test_decagon_layer_01():
      """Smoke test: ``InputLayer`` followed by ``DecagonLayer`` runs forward.

      Only checks that the forward pass completes without raising; the
      output is not inspected.
      """
      data = _some_data_with_interactions()
      input_layer = InputLayer(data)
      decagon_layer = DecagonLayer(input_layer.output_dim, 32, data)
      model = torch.nn.Sequential(input_layer, decagon_layer)
      # Dummy call: the pipeline is fed None (the input layer presumably
      # produces its own features from `data`).
      model(None)
  def test_decagon_layer_02():
      """Smoke test: ``OneHotInputLayer`` followed by ``DecagonLayer`` runs forward.

      Only checks that the forward pass completes without raising; the
      output is not inspected.
      """
      data = _some_data_with_interactions()
      input_layer = OneHotInputLayer(data)
      decagon_layer = DecagonLayer(input_layer.output_dim, 32, data)
      model = torch.nn.Sequential(input_layer, decagon_layer)
      # Dummy call: the pipeline is fed None (the input layer presumably
      # produces its own features from `data`).
      model(None)
  def test_decagon_layer_03():
      """Check ``DecagonLayer`` attributes, defaults and ``next_layer_repr`` layout.

      Verifies the stored input/output dimensions and data object, the
      default ``keep_prob``, that ``rel_activation`` leaves 0.5 unchanged,
      that ``layer_activation`` agrees with ReLU on a few values, that the
      layer is not sparse, and that ``next_layer_repr`` holds, for each of the
      two node types, a two-element list whose first entry is a
      ``Convolutions`` made of ``DropoutGraphConvActivation`` modules with
      output dimension 32.
      """
      data = _some_data_with_interactions()
      input_layer = OneHotInputLayer(data)
      layer = DecagonLayer(input_layer.output_dim, 32, data)

      assert layer.input_dim == [1000, 100]
      assert layer.output_dim == [32, 32]
      assert layer.data == data
      assert layer.keep_prob == 1.
      assert layer.rel_activation(0.5) == 0.5

      probe = torch.tensor([-1, 0, 0.5, 1])
      expected = torch.nn.functional.relu(probe)
      assert (layer.layer_activation(probe) == expected).all()

      assert not layer.is_sparse

      per_node_type = layer.next_layer_repr
      assert len(per_node_type) == 2
      for i in range(2):
          entries = per_node_type[i]
          assert len(entries) == 2
          assert isinstance(entries, list)
          first = entries[0]
          assert isinstance(first, Convolutions)
          assert isinstance(first.node_type_column, int)
          assert isinstance(first.convolutions, list)
          assert all(isinstance(dgca, DropoutGraphConvActivation)
                     for dgca in first.convolutions)
          assert all(dgca.output_dim == 32 for dgca in first.convolutions)
  def test_decagon_layer_04():
      """``DecagonLayer`` must match ``MultiDGCA`` for a single relation.

      Builds one 100-node type with one symmetric sparse self-relation and
      creates both layers with identity activations and ``keep_prob=1.``.
      It then copies the DecagonLayer's graph-convolution weight into the
      MultiDGCA and asserts that both forward passes give identical outputs.
      """
      d = Data()
      d.add_node_type('Dummy', 100)
      fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
      fam.add_relation_type('Dummy Relation',
                            _symmetric_random(100, 100).to_sparse())
      in_layer = OneHotInputLayer(d)
      # Input dimension must equal the 100-dimensional one-hot features of the
      # 'Dummy' node type (as in test_decagon_layer_05). A mismatched value
      # would only be hidden by the weight copy below.
      multi_dgca = MultiDGCA([100], 32,
                             [r.adjacency_matrix for r in fam.relation_types],
                             keep_prob=1., activation=lambda x: x)
      d_layer = DecagonLayer(in_layer.output_dim, 32, d,
                             keep_prob=1., rel_activation=lambda x: x,
                             layer_activation=lambda x: x)
      assert isinstance(d_layer.next_layer_repr[0][0].convolutions[0],
                        DropoutGraphConvActivation)
      weight = d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight
      assert isinstance(weight, torch.Tensor)
      assert len(multi_dgca.dgca) == 1
      assert isinstance(multi_dgca.dgca[0],
                        decagon_pytorch.convolve.DropoutGraphConvActivation)
      # Share the weight so both implementations compute with identical
      # parameters.
      multi_dgca.dgca[0].graph_conv.weight = weight
      seq = torch.nn.Sequential(in_layer, d_layer)
      out_d_layer = seq(None)
      out_multi_dgca = multi_dgca(in_layer(None))
      assert isinstance(out_d_layer, list)
      assert len(out_d_layer) == 1
      assert torch.all(out_d_layer[0] == out_multi_dgca)
  def test_decagon_layer_05():
      """``DecagonLayer`` must match ``MultiDGCA`` for two relations of one edge type.

      Builds one 100-node type with two symmetric sparse relations in the same
      family and creates both layers with identity activations and
      ``keep_prob=1.``. It then copies the DecagonLayer's per-relation weights
      into the MultiDGCA and asserts that both forward passes give identical
      outputs.
      """
      data = Data()
      data.add_node_type('Dummy', 100)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, True)
      for relation_name in ('Dummy Relation 1', 'Dummy Relation 2'):
          family.add_relation_type(relation_name,
                                   _symmetric_random(100, 100).to_sparse())

      input_layer = OneHotInputLayer(data)
      adjacency = [rel.adjacency_matrix for rel in family.relation_types]
      multi_dgca = MultiDGCA([100, 100], 32, adjacency,
                             keep_prob=1., activation=lambda x: x)
      d_layer = DecagonLayer(input_layer.output_dim, output_dim=32, data=data,
                             keep_prob=1., rel_activation=lambda x: x,
                             layer_activation=lambda x: x)

      convolutions = d_layer.next_layer_repr[0][0].convolutions
      assert all(isinstance(dgca, DropoutGraphConvActivation)
                 for dgca in convolutions)
      weights = [dgca.graph_conv.weight for dgca in convolutions]
      assert all(isinstance(w, torch.Tensor) for w in weights)

      assert len(multi_dgca.dgca) == 2
      for reference_dgca in multi_dgca.dgca:
          assert isinstance(reference_dgca,
                            decagon_pytorch.convolve.DropoutGraphConvActivation)
      # Share weights relation-by-relation so both layers compute identically.
      for idx, reference_dgca in enumerate(multi_dgca.dgca):
          reference_dgca.graph_conv.weight = weights[idx]

      model = torch.nn.Sequential(input_layer, d_layer)
      out_d_layer = model(None)
      features = input_layer(None)
      out_multi_dgca = multi_dgca([features[0], features[0]])

      assert isinstance(out_d_layer, list)
      assert len(out_d_layer) == 1
      assert torch.all(out_d_layer[0] == out_multi_dgca)
  class Dummy1(torch.nn.Module):
      """Leaf module owning one registered 10x10 parameter named ``whatever``.

      The module-nesting tests use it to observe which device a parameter
      ends up on after ``Module.to``.
      """

      def __init__(self, **kwargs):
          super().__init__(**kwargs)
          initial = torch.rand(10, 10)
          self.register_parameter('whatever', torch.nn.Parameter(initial))
  class Dummy2(torch.nn.Module):
      """Module holding a ``Dummy1`` as a registered child named ``dummy_1``.

      Because the child is registered, ``Module.to`` also moves its
      parameters (checked by ``test_module_nesting_01``).
      """

      def __init__(self, **kwargs):
          super().__init__(**kwargs)
          self.add_module('dummy_1', Dummy1())
  class Dummy3(torch.nn.Module):
      """Module keeping its ``Dummy1`` child inside a plain Python list.

      A plain list is not registered as a submodule, so ``Module.to`` does not
      reach the contained module (checked by ``test_module_nesting_02``).
      """

      def __init__(self, **kwargs):
          super().__init__(**kwargs)
          # Deliberately a plain list, not torch.nn.ModuleList.
          children = [Dummy1()]
          self.dummy_1 = children
  class Dummy4(torch.nn.Module):
      """Module keeping its ``Dummy1`` child inside a ``torch.nn.ModuleList``.

      Unlike a plain list, ``ModuleList`` registers its elements, so
      ``Module.to`` moves the contained parameters (checked by
      ``test_module_nesting_03``).
      """

      def __init__(self, **kwargs):
          super().__init__(**kwargs)
          children = torch.nn.ModuleList()
          children.append(Dummy1())
          self.dummy_1 = children
  def test_module_nesting_01():
      """A module assigned directly as an attribute is moved by ``Module.to``.

      Needs a CUDA device; skipped instead of erroring when none is available.
      """
      if not torch.cuda.is_available():
          pytest.skip('CUDA device required')
      device = torch.device('cuda:0')
      dummy_2 = Dummy2()
      dummy_2 = dummy_2.to(device)
      assert dummy_2.dummy_1.whatever.device == device
  def test_module_nesting_02():
      """A module hidden in a plain list is NOT moved by ``Module.to``.

      Needs a CUDA device; skipped instead of erroring when none is available.
      """
      if not torch.cuda.is_available():
          pytest.skip('CUDA device required')
      device = torch.device('cuda:0')
      dummy_3 = Dummy3()
      dummy_3 = dummy_3.to(device)
      # The plain list is invisible to Module.to, so the parameter stays put.
      assert dummy_3.dummy_1[0].whatever.device != device
  def test_module_nesting_03():
      """A module held in a ``torch.nn.ModuleList`` is moved by ``Module.to``.

      Needs a CUDA device; skipped instead of erroring when none is available.
      """
      if not torch.cuda.is_available():
          pytest.skip('CUDA device required')
      device = torch.device('cuda:0')
      dummy_4 = Dummy4()
      dummy_4 = dummy_4.to(device)
      assert dummy_4.dummy_1[0].whatever.device == device