IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

127 строк
4.5KB

  1. from icosagon.databatch import DataBatcher, \
  2. BatchedData, \
  3. BatchedDataPointer, \
  4. batched_data_skeleton
  5. from icosagon.data import Data
  6. from icosagon.trainprep import prepare_training, \
  7. TrainValTest
  8. from icosagon.declayer import DecodeLayer
  9. from icosagon.input import OneHotInputLayer
  10. import torch
  11. import time
  12. def _some_data():
  13. data = Data()
  14. data.add_node_type('Foo', 100)
  15. data.add_node_type('Bar', 500)
  16. fam = data.add_relation_family('Foo-Bar', 0, 1, True)
  17. adj_mat = torch.rand(100, 500).round().to_sparse()
  18. fam.add_relation_type('Foo-Bar', adj_mat)
  19. return data
  20. def _some_data_big():
  21. data = Data()
  22. data.add_node_type('Foo', 2000)
  23. data.add_node_type('Bar', 2100)
  24. fam = data.add_relation_family('Foo-Bar', 0, 1, True)
  25. adj_mat = torch.rand(2000, 2100).round().to_sparse()
  26. fam.add_relation_type('Foo-Bar', adj_mat)
  27. return data
  28. def test_data_batcher_01():
  29. data = _some_data()
  30. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  31. batcher = DataBatcher(prep_d, 512)
  32. def test_data_batcher_02():
  33. data = _some_data()
  34. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  35. batcher = DataBatcher(prep_d, 512)
  36. for batch_d in batcher:
  37. pass
  38. def test_data_batcher_03():
  39. data = _some_data()
  40. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  41. batcher = DataBatcher(prep_d, 512)
  42. for batch_d in batcher:
  43. edges_list = []
  44. for fam in batch_d.relation_families:
  45. for rel in fam.relation_types:
  46. for edge_type in ['edges_pos', 'edges_neg',
  47. 'edges_back_pos', 'edges_back_neg']:
  48. for part_type in ['train', 'val', 'test']:
  49. edges = getattr(getattr(rel, edge_type), part_type)
  50. edges_list.append(edges)
  51. assert sum([ 1 for edges in edges_list if len(edges) > 0 ]) == 1
  52. def test_data_batcher_04():
  53. data = _some_data()
  54. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  55. batcher = DataBatcher(prep_d, 512)
  56. edges_list = []
  57. for batch_d in batcher:
  58. for fam in batch_d.relation_families:
  59. for rel in fam.relation_types:
  60. for edge_type in ['edges_pos', 'edges_neg',
  61. 'edges_back_pos', 'edges_back_neg']:
  62. for part_type in ['train', 'val', 'test']:
  63. edges = getattr(getattr(rel, edge_type), part_type)
  64. edges_list.append(edges)
  65. assert sum([ len(edges) for edges in edges_list ]) == \
  66. torch.sum(data.relation_families[0].relation_types[0].adjacency_matrix._values()) * 2
  67. def test_data_batcher_05():
  68. data = _some_data()
  69. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  70. batcher = DataBatcher(prep_d, 512)
  71. for batch_d in batcher:
  72. edges_list = []
  73. for fam in batch_d.relation_families:
  74. for rel in fam.relation_types:
  75. for edge_type in ['edges_pos', 'edges_neg',
  76. 'edges_back_pos', 'edges_back_neg']:
  77. for part_type in ['train', 'val', 'test']:
  78. edges = getattr(getattr(rel, edge_type), part_type)
  79. edges_list.append(edges)
  80. assert all([ len(edges) <= 512 for edges in edges_list ])
  81. assert not all([ len(edges) == 0 for edges in edges_list ])
  82. print(sum(map(len, edges_list)))
  83. def test_batch_decode_01():
  84. data = _some_data()
  85. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  86. batcher = DataBatcher(prep_d, 512)
  87. ptr = BatchedDataPointer(batched_data_skeleton(prep_d))
  88. in_repr = [ torch.rand(100, 32),
  89. torch.rand(500, 32) ]
  90. dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr)
  91. t = time.time()
  92. for batched_data in batcher:
  93. ptr.batched_data = batched_data
  94. _ = dec_layer(in_repr)
  95. print('Elapsed:', time.time() - t)
  96. def test_batch_decode_02():
  97. data = _some_data_big()
  98. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  99. batcher = DataBatcher(prep_d, 512)
  100. ptr = BatchedDataPointer(batched_data_skeleton(prep_d))
  101. in_repr = [ torch.rand(2000, 32),
  102. torch.rand(2100, 32) ]
  103. dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr)
  104. t = time.time()
  105. for batched_data in batcher:
  106. ptr.batched_data = batched_data
  107. _ = dec_layer(in_repr)
  108. print('Elapsed:', time.time() - t)