IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

241 lines
9.7KB

  1. from icosagon.data import Data
  2. from icosagon.bulkdec import BulkDecodeLayer
  3. from icosagon.input import OneHotInputLayer
  4. from icosagon.convlayer import DecagonLayer
  5. import torch
  6. import pytest
  7. import time
  8. import sys
  9. def test_bulk_decode_layer_01():
  10. data = Data()
  11. data.add_node_type('Dummy', 100)
  12. fam = data.add_relation_family('Dummy-Dummy', 0, 0, False)
  13. fam.add_relation_type('Dummy Relation 1',
  14. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  15. in_layer = OneHotInputLayer(data)
  16. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  17. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  18. keep_prob=1., activation=lambda x: x)
  19. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  20. pred = seq(None)
  21. assert isinstance(pred, list)
  22. assert len(pred) == len(data.relation_families)
  23. assert isinstance(pred[0], torch.Tensor)
  24. assert len(pred[0].shape) == 3
  25. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  26. assert pred[0].shape[1] == data.node_types[0].count
  27. assert pred[0].shape[2] == data.node_types[0].count
  28. def test_bulk_decode_layer_02():
  29. data = Data()
  30. data.add_node_type('Foo', 100)
  31. data.add_node_type('Bar', 50)
  32. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  33. fam.add_relation_type('Foobar Relation 1',
  34. torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
  35. torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
  36. in_layer = OneHotInputLayer(data)
  37. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  38. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  39. keep_prob=1., activation=lambda x: x)
  40. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  41. pred = seq(None)
  42. assert isinstance(pred, list)
  43. assert len(pred) == len(data.relation_families)
  44. assert isinstance(pred[0], torch.Tensor)
  45. assert len(pred[0].shape) == 3
  46. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  47. assert pred[0].shape[1] == data.node_types[0].count
  48. assert pred[0].shape[2] == data.node_types[1].count
  49. def test_bulk_decode_layer_03():
  50. data = Data()
  51. data.add_node_type('Foo', 100)
  52. data.add_node_type('Bar', 50)
  53. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  54. fam.add_relation_type('Foobar Relation 1',
  55. torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
  56. torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
  57. fam.add_relation_type('Foobar Relation 2',
  58. torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
  59. torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
  60. in_layer = OneHotInputLayer(data)
  61. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  62. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  63. keep_prob=1., activation=lambda x: x)
  64. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  65. pred = seq(None)
  66. assert isinstance(pred, list)
  67. assert len(pred) == len(data.relation_families)
  68. assert isinstance(pred[0], torch.Tensor)
  69. assert len(pred[0].shape) == 3
  70. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  71. assert pred[0].shape[1] == data.node_types[0].count
  72. assert pred[0].shape[2] == data.node_types[1].count
  73. def test_bulk_decode_layer_03_big():
  74. data = Data()
  75. data.add_node_type('Foo', 2000)
  76. data.add_node_type('Bar', 2100)
  77. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  78. fam.add_relation_type('Foobar Relation 1',
  79. torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
  80. torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
  81. fam.add_relation_type('Foobar Relation 2',
  82. torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
  83. torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
  84. in_layer = OneHotInputLayer(data)
  85. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  86. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  87. keep_prob=1., activation=lambda x: x)
  88. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  89. pred = seq(None)
  90. assert isinstance(pred, list)
  91. assert len(pred) == len(data.relation_families)
  92. assert isinstance(pred[0], torch.Tensor)
  93. assert len(pred[0].shape) == 3
  94. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  95. assert pred[0].shape[1] == data.node_types[0].count
  96. assert pred[0].shape[2] == data.node_types[1].count
  97. def test_bulk_decode_layer_03_huge_gpu():
  98. if torch.cuda.device_count() == 0:
  99. pytest.skip('test_bulk_decode_layer_03_huge_gpu() requires CUDA support')
  100. device = torch.device('cuda:0')
  101. data = Data()
  102. data.add_node_type('Foo', 20000)
  103. data.add_node_type('Bar', 21000)
  104. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  105. print('Adding Foobar Relation 1...')
  106. fam.add_relation_type('Foobar Relation 1',
  107. torch.rand((20000, 21000), dtype=torch.float32).round().to_sparse().to(device),
  108. torch.rand((21000, 20000), dtype=torch.float32).round().to_sparse().to(device))
  109. print('Adding Foobar Relation 2...')
  110. fam.add_relation_type('Foobar Relation 2',
  111. torch.rand((20000, 21000), dtype=torch.float32).round().to_sparse().to(device),
  112. torch.rand((21000, 20000), dtype=torch.float32).round().to_sparse().to(device))
  113. in_layer = OneHotInputLayer(data)
  114. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  115. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  116. keep_prob=1., activation=lambda x: x)
  117. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  118. seq = seq.to(device)
  119. print('Starting forward pass...')
  120. t = time.time()
  121. pred = seq(None)
  122. print('Elapsed:', time.time() - t)
  123. assert isinstance(pred, list)
  124. assert len(pred) == len(data.relation_families)
  125. assert isinstance(pred[0], torch.Tensor)
  126. assert len(pred[0].shape) == 3
  127. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  128. assert pred[0].shape[1] == data.node_types[0].count
  129. assert pred[0].shape[2] == data.node_types[1].count
  130. def test_bulk_decode_layer_04_huge_multirel_gpu():
  131. if torch.cuda.device_count() == 0:
  132. pytest.skip('test_bulk_decode_layer_04_huge_multirel_gpu() requires CUDA support')
  133. if torch.cuda.get_device_properties(0).total_memory < 64000000000:
  134. pytest.skip('test_bulk_decode_layer_04_huge_multirel_gpu() requires GPU with 64GB of memory')
  135. device = torch.device('cuda:0')
  136. data = Data()
  137. data.add_node_type('Foo', 20000)
  138. data.add_node_type('Bar', 21000)
  139. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  140. print('Generating adj_mat ...')
  141. adj_mat = torch.rand((20000, 21000), dtype=torch.float32).round().to_sparse().to(device)
  142. print('Generating adj_mat_back ...')
  143. adj_mat_back = torch.rand((21000, 20000), dtype=torch.float32).round().to_sparse().to(device)
  144. print('Adding relations ...')
  145. for i in range(1300):
  146. sys.stdout.write('.')
  147. sys.stdout.flush()
  148. fam.add_relation_type(f'Foobar Relation {i}', adj_mat, adj_mat_back)
  149. print()
  150. in_layer = OneHotInputLayer(data)
  151. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  152. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  153. keep_prob=1., activation=lambda x: x)
  154. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  155. seq = seq.to(device)
  156. print('Starting forward pass...')
  157. t = time.time()
  158. pred = seq(None)
  159. print('Elapsed:', time.time() - t)
  160. assert isinstance(pred, list)
  161. assert len(pred) == len(data.relation_families)
  162. assert isinstance(pred[0], torch.Tensor)
  163. assert len(pred[0].shape) == 3
  164. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  165. assert pred[0].shape[1] == data.node_types[0].count
  166. assert pred[0].shape[2] == data.node_types[1].count
  167. def test_bulk_decode_layer_04_big_multirel_gpu():
  168. if torch.cuda.device_count() == 0:
  169. pytest.skip('test_bulk_decode_layer_04_big_multirel_gpu() requires CUDA support')
  170. device = torch.device('cuda:0')
  171. data = Data()
  172. data.add_node_type('Foo', 2000)
  173. data.add_node_type('Bar', 2100)
  174. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  175. print('Generating adj_mat ...')
  176. adj_mat = torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse().to(device)
  177. print('Generating adj_mat_back ...')
  178. adj_mat_back = torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse().to(device)
  179. print('Adding relations ...')
  180. for i in range(1300):
  181. sys.stdout.write('.')
  182. sys.stdout.flush()
  183. fam.add_relation_type(f'Foobar Relation {i}', adj_mat, adj_mat_back)
  184. print()
  185. in_layer = OneHotInputLayer(data)
  186. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  187. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  188. keep_prob=1., activation=lambda x: x)
  189. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  190. seq = seq.to(device)
  191. print('Starting forward pass...')
  192. t = time.time()
  193. pred = seq(None)
  194. print('Elapsed:', time.time() - t)
  195. assert isinstance(pred, list)
  196. assert len(pred) == len(data.relation_families)
  197. assert isinstance(pred[0], torch.Tensor)
  198. assert len(pred[0].shape) == 3
  199. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  200. assert pred[0].shape[1] == data.node_types[0].count
  201. assert pred[0].shape[2] == data.node_types[1].count