IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
您最多可以选择25个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过35个字符。

241 行
9.7KB

  1. from icosagon.data import Data
  2. from icosagon.bulkdec import BulkDecodeLayer
  3. from icosagon.input import OneHotInputLayer
  4. from icosagon.convlayer import DecagonLayer
  5. import torch
  6. import pytest
  7. import time
  8. import sys
  9. def test_bulk_decode_layer_01():
  10. data = Data()
  11. data.add_node_type('Dummy', 100)
  12. fam = data.add_relation_family('Dummy-Dummy', 0, 0, False)
  13. fam.add_relation_type('Dummy Relation 1',
  14. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  15. in_layer = OneHotInputLayer(data)
  16. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  17. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  18. keep_prob=1., activation=lambda x: x)
  19. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  20. pred = seq(None)
  21. assert isinstance(pred, list)
  22. assert len(pred) == len(data.relation_families)
  23. assert isinstance(pred[0], torch.Tensor)
  24. assert len(pred[0].shape) == 3
  25. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  26. assert pred[0].shape[1] == data.node_types[0].count
  27. assert pred[0].shape[2] == data.node_types[0].count
  28. def test_bulk_decode_layer_02():
  29. data = Data()
  30. data.add_node_type('Foo', 100)
  31. data.add_node_type('Bar', 50)
  32. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  33. fam.add_relation_type('Foobar Relation 1',
  34. torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
  35. torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
  36. in_layer = OneHotInputLayer(data)
  37. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  38. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  39. keep_prob=1., activation=lambda x: x)
  40. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  41. pred = seq(None)
  42. assert isinstance(pred, list)
  43. assert len(pred) == len(data.relation_families)
  44. assert isinstance(pred[0], torch.Tensor)
  45. assert len(pred[0].shape) == 3
  46. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  47. assert pred[0].shape[1] == data.node_types[0].count
  48. assert pred[0].shape[2] == data.node_types[1].count
  49. def test_bulk_decode_layer_03():
  50. data = Data()
  51. data.add_node_type('Foo', 100)
  52. data.add_node_type('Bar', 50)
  53. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  54. fam.add_relation_type('Foobar Relation 1',
  55. torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
  56. torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
  57. fam.add_relation_type('Foobar Relation 2',
  58. torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
  59. torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
  60. in_layer = OneHotInputLayer(data)
  61. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  62. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  63. keep_prob=1., activation=lambda x: x)
  64. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  65. pred = seq(None)
  66. assert isinstance(pred, list)
  67. assert len(pred) == len(data.relation_families)
  68. assert isinstance(pred[0], torch.Tensor)
  69. assert len(pred[0].shape) == 3
  70. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  71. assert pred[0].shape[1] == data.node_types[0].count
  72. assert pred[0].shape[2] == data.node_types[1].count
  73. def test_bulk_decode_layer_03_big():
  74. data = Data()
  75. data.add_node_type('Foo', 2000)
  76. data.add_node_type('Bar', 2100)
  77. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  78. fam.add_relation_type('Foobar Relation 1',
  79. torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
  80. torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
  81. fam.add_relation_type('Foobar Relation 2',
  82. torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
  83. torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
  84. in_layer = OneHotInputLayer(data)
  85. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  86. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  87. keep_prob=1., activation=lambda x: x)
  88. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  89. pred = seq(None)
  90. assert isinstance(pred, list)
  91. assert len(pred) == len(data.relation_families)
  92. assert isinstance(pred[0], torch.Tensor)
  93. assert len(pred[0].shape) == 3
  94. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  95. assert pred[0].shape[1] == data.node_types[0].count
  96. assert pred[0].shape[2] == data.node_types[1].count
  97. def test_bulk_decode_layer_03_huge_gpu():
  98. if torch.cuda.device_count() == 0:
  99. pytest.skip('test_bulk_decode_layer_03_huge_gpu() requires CUDA support')
  100. device = torch.device('cuda:0')
  101. data = Data()
  102. data.add_node_type('Foo', 20000)
  103. data.add_node_type('Bar', 21000)
  104. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  105. print('Adding Foobar Relation 1...')
  106. fam.add_relation_type('Foobar Relation 1',
  107. torch.rand((20000, 21000), dtype=torch.float32).round().to_sparse().to(device),
  108. torch.rand((21000, 20000), dtype=torch.float32).round().to_sparse().to(device))
  109. print('Adding Foobar Relation 2...')
  110. fam.add_relation_type('Foobar Relation 2',
  111. torch.rand((20000, 21000), dtype=torch.float32).round().to_sparse().to(device),
  112. torch.rand((21000, 20000), dtype=torch.float32).round().to_sparse().to(device))
  113. in_layer = OneHotInputLayer(data)
  114. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  115. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  116. keep_prob=1., activation=lambda x: x)
  117. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  118. seq = seq.to(device)
  119. print('Starting forward pass...')
  120. t = time.time()
  121. pred = seq(None)
  122. print('Elapsed:', time.time() - t)
  123. assert isinstance(pred, list)
  124. assert len(pred) == len(data.relation_families)
  125. assert isinstance(pred[0], torch.Tensor)
  126. assert len(pred[0].shape) == 3
  127. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  128. assert pred[0].shape[1] == data.node_types[0].count
  129. assert pred[0].shape[2] == data.node_types[1].count
  130. def test_bulk_decode_layer_04_huge_multirel_gpu():
  131. if torch.cuda.device_count() == 0:
  132. pytest.skip('test_bulk_decode_layer_04_huge_multirel_gpu() requires CUDA support')
  133. if torch.cuda.get_device_properties(0).total_memory < 64000000000:
  134. pytest.skip('test_bulk_decode_layer_04_huge_multirel_gpu() requires GPU with 64GB of memory')
  135. device = torch.device('cuda:0')
  136. data = Data()
  137. data.add_node_type('Foo', 20000)
  138. data.add_node_type('Bar', 21000)
  139. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  140. print('Generating adj_mat ...')
  141. adj_mat = torch.rand((20000, 21000), dtype=torch.float32).round().to_sparse().to(device)
  142. print('Generating adj_mat_back ...')
  143. adj_mat_back = torch.rand((21000, 20000), dtype=torch.float32).round().to_sparse().to(device)
  144. print('Adding relations ...')
  145. for i in range(1300):
  146. sys.stdout.write('.')
  147. sys.stdout.flush()
  148. fam.add_relation_type(f'Foobar Relation {i}', adj_mat, adj_mat_back)
  149. print()
  150. in_layer = OneHotInputLayer(data)
  151. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  152. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  153. keep_prob=1., activation=lambda x: x)
  154. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  155. seq = seq.to(device)
  156. print('Starting forward pass...')
  157. t = time.time()
  158. pred = seq(None)
  159. print('Elapsed:', time.time() - t)
  160. assert isinstance(pred, list)
  161. assert len(pred) == len(data.relation_families)
  162. assert isinstance(pred[0], torch.Tensor)
  163. assert len(pred[0].shape) == 3
  164. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  165. assert pred[0].shape[1] == data.node_types[0].count
  166. assert pred[0].shape[2] == data.node_types[1].count
  167. def test_bulk_decode_layer_04_big_multirel_gpu():
  168. if torch.cuda.device_count() == 0:
  169. pytest.skip('test_bulk_decode_layer_04_big_multirel_gpu() requires CUDA support')
  170. device = torch.device('cuda:0')
  171. data = Data()
  172. data.add_node_type('Foo', 2000)
  173. data.add_node_type('Bar', 2100)
  174. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  175. print('Generating adj_mat ...')
  176. adj_mat = torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse().to(device)
  177. print('Generating adj_mat_back ...')
  178. adj_mat_back = torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse().to(device)
  179. print('Adding relations ...')
  180. for i in range(1300):
  181. sys.stdout.write('.')
  182. sys.stdout.flush()
  183. fam.add_relation_type(f'Foobar Relation {i}', adj_mat, adj_mat_back)
  184. print()
  185. in_layer = OneHotInputLayer(data)
  186. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  187. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  188. keep_prob=1., activation=lambda x: x)
  189. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  190. seq = seq.to(device)
  191. print('Starting forward pass...')
  192. t = time.time()
  193. pred = seq(None)
  194. print('Elapsed:', time.time() - t)
  195. assert isinstance(pred, list)
  196. assert len(pred) == len(data.relation_families)
  197. assert isinstance(pred[0], torch.Tensor)
  198. assert len(pred[0].shape) == 3
  199. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  200. assert pred[0].shape[1] == data.node_types[0].count
  201. assert pred[0].shape[2] == data.node_types[1].count