IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You can not select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

211 lines
7.0KB

  1. from icosagon.fastconv import _sparse_diag_cat, \
  2. _cat, \
  3. FastGraphConv, \
  4. FastConvLayer
  5. from icosagon.data import _equal
  6. import torch
  7. import pdb
  8. import time
  9. from icosagon.data import Data
  10. from icosagon.input import OneHotInputLayer
  11. from icosagon.convlayer import DecagonLayer
  12. def _make_symmetric(x: torch.Tensor):
  13. x = (x + x.transpose(0, 1)) / 2
  14. return x
  15. def _symmetric_random(n_rows, n_columns):
  16. return _make_symmetric(torch.rand((n_rows, n_columns),
  17. dtype=torch.float32).round().to_sparse())
  18. def _some_data_with_interactions():
  19. d = Data()
  20. d.add_node_type('Gene', 1000)
  21. d.add_node_type('Drug', 100)
  22. fam = d.add_relation_family('Drug-Gene', 1, 0, True)
  23. fam.add_relation_type('Target',
  24. torch.rand((100, 1000), dtype=torch.float32).round().to_sparse())
  25. fam = d.add_relation_family('Gene-Gene', 0, 0, True)
  26. fam.add_relation_type('Interaction',
  27. _symmetric_random(1000, 1000))
  28. fam = d.add_relation_family('Drug-Drug', 1, 1, True)
  29. fam.add_relation_type('Side Effect: Nausea',
  30. _symmetric_random(100, 100))
  31. fam.add_relation_type('Side Effect: Infertility',
  32. _symmetric_random(100, 100))
  33. fam.add_relation_type('Side Effect: Death',
  34. _symmetric_random(100, 100))
  35. return d
  36. def test_sparse_diag_cat_01():
  37. matrices = [ torch.rand(5, 10).round() for _ in range(7) ]
  38. ground_truth = torch.zeros(35, 70)
  39. ground_truth[0:5, 0:10] = matrices[0]
  40. ground_truth[5:10, 10:20] = matrices[1]
  41. ground_truth[10:15, 20:30] = matrices[2]
  42. ground_truth[15:20, 30:40] = matrices[3]
  43. ground_truth[20:25, 40:50] = matrices[4]
  44. ground_truth[25:30, 50:60] = matrices[5]
  45. ground_truth[30:35, 60:70] = matrices[6]
  46. res = _sparse_diag_cat([ m.to_sparse() for m in matrices ])
  47. res = res.to_dense()
  48. assert torch.all(res == ground_truth)
  49. def test_sparse_diag_cat_02():
  50. x = [ torch.rand(5, 10).round() for _ in range(7) ]
  51. a = [ m.to_sparse() for m in x ]
  52. a = _sparse_diag_cat(a)
  53. b = torch.rand(70, 64)
  54. res = torch.sparse.mm(a, b)
  55. ground_truth = torch.zeros(35, 64)
  56. ground_truth[0:5, :] = torch.mm(x[0], b[0:10])
  57. ground_truth[5:10, :] = torch.mm(x[1], b[10:20])
  58. ground_truth[10:15, :] = torch.mm(x[2], b[20:30])
  59. ground_truth[15:20, :] = torch.mm(x[3], b[30:40])
  60. ground_truth[20:25, :] = torch.mm(x[4], b[40:50])
  61. ground_truth[25:30, :] = torch.mm(x[5], b[50:60])
  62. ground_truth[30:35, :] = torch.mm(x[6], b[60:70])
  63. assert torch.all(res == ground_truth)
  64. def test_cat_01():
  65. matrices = [ torch.rand(5, 10) for _ in range(7) ]
  66. res = _cat(matrices)
  67. assert res.shape == (35, 10)
  68. assert not res.is_sparse
  69. ground_truth = torch.zeros(35, 10)
  70. for i in range(7):
  71. ground_truth[i*5:(i+1)*5, :] = matrices[i]
  72. assert torch.all(res == ground_truth)
  73. def test_cat_02():
  74. matrices = [ torch.rand(5, 10) for _ in range(7) ]
  75. ground_truth = torch.zeros(35, 10)
  76. for i in range(7):
  77. ground_truth[i*5:(i+1)*5, :] = matrices[i]
  78. res = _cat([ m.to_sparse() for m in matrices ])
  79. assert res.shape == (35, 10)
  80. assert res.is_sparse
  81. assert torch.all(res.to_dense() == ground_truth)
  82. def test_fast_graph_conv_01():
  83. # pdb.set_trace()
  84. adj_mats = [ torch.rand(10, 15).round().to_sparse() \
  85. for _ in range(23) ]
  86. fgc = FastGraphConv(32, 64, adj_mats)
  87. in_repr = torch.rand(15, 32)
  88. _ = fgc(in_repr)
  89. def test_fast_graph_conv_02():
  90. t = time.time()
  91. m = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
  92. adj_mats = [ m for _ in range(1300) ]
  93. print('Generating adj_mats took:', time.time() - t)
  94. t = time.time()
  95. fgc = FastGraphConv(32, 64, adj_mats)
  96. print('FGC constructor took:', time.time() - t)
  97. in_repr = torch.rand(2000, 32)
  98. for _ in range(3):
  99. t = time.time()
  100. _ = fgc(in_repr)
  101. print('FGC forward pass took:', time.time() - t)
  102. def test_fast_graph_conv_03():
  103. adj_mat = [
  104. [ 0, 0, 1, 0, 1 ],
  105. [ 0, 1, 0, 1, 0 ],
  106. [ 1, 0, 1, 0, 0 ]
  107. ]
  108. in_repr = torch.rand(5, 32)
  109. adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
  110. fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse() ])
  111. out_repr = fgc(in_repr)
  112. assert out_repr.shape == (1, 3, 64)
  113. assert (torch.mm(adj_mat, torch.mm(in_repr, fgc.weights)).view(1, 3, 64) == out_repr).all()
  114. def test_fast_graph_conv_04():
  115. adj_mat = [
  116. [ 0, 0, 1, 0, 1 ],
  117. [ 0, 1, 0, 1, 0 ],
  118. [ 1, 0, 1, 0, 0 ]
  119. ]
  120. in_repr = torch.rand(5, 32)
  121. adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
  122. fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse(), adj_mat.to_sparse() ])
  123. out_repr = fgc(in_repr)
  124. assert out_repr.shape == (2, 3, 64)
  125. adj_mat_1 = torch.zeros(adj_mat.shape[0] * 2, adj_mat.shape[1] * 2)
  126. adj_mat_1[0:3, 0:5] = adj_mat
  127. adj_mat_1[3:6, 5:10] = adj_mat
  128. res = torch.mm(in_repr, fgc.weights)
  129. res = torch.split(res, res.shape[1] // 2, dim=1)
  130. res = torch.cat(res)
  131. res = torch.mm(adj_mat_1, res)
  132. assert (res.view(2, 3, 64) == out_repr).all()
  133. def test_fast_conv_layer_01():
  134. d = _some_data_with_interactions()
  135. in_layer = OneHotInputLayer(d)
  136. d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
  137. seq_1 = torch.nn.Sequential(in_layer, d_layer)
  138. _ = seq_1(None)
  139. conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
  140. seq_2 = torch.nn.Sequential(in_layer, conv_layer)
  141. _ = seq_2(None)
  142. def test_fast_conv_layer_02():
  143. d = _some_data_with_interactions()
  144. in_layer = OneHotInputLayer(d)
  145. d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
  146. seq_1 = torch.nn.Sequential(in_layer, d_layer)
  147. out_repr_1 = seq_1(None)
  148. assert len(d_layer.next_layer_repr[0]) == 2
  149. assert len(d_layer.next_layer_repr[1]) == 2
  150. conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
  151. assert len(conv_layer.next_layer_repr[1]) == 2
  152. conv_layer.next_layer_repr[1][0].weights = torch.cat([
  153. d_layer.next_layer_repr[1][0].convolutions[0].graph_conv.weight,
  154. ], dim=1)
  155. conv_layer.next_layer_repr[1][1].weights = torch.cat([
  156. d_layer.next_layer_repr[1][1].convolutions[0].graph_conv.weight,
  157. d_layer.next_layer_repr[1][1].convolutions[1].graph_conv.weight,
  158. d_layer.next_layer_repr[1][1].convolutions[2].graph_conv.weight,
  159. ], dim=1)
  160. assert len(conv_layer.next_layer_repr[0]) == 2
  161. conv_layer.next_layer_repr[0][0].weights = torch.cat([
  162. d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight,
  163. ], dim=1)
  164. conv_layer.next_layer_repr[0][1].weights = torch.cat([
  165. d_layer.next_layer_repr[0][1].convolutions[0].graph_conv.weight,
  166. ], dim=1)
  167. seq_2 = torch.nn.Sequential(in_layer, conv_layer)
  168. out_repr_2 = seq_2(None)
  169. assert len(out_repr_1) == len(out_repr_2)
  170. for i in range(len(out_repr_1)):
  171. assert torch.all(out_repr_1[i] == out_repr_2[i])