IF YOU WOULD LIKE TO GET AN ACCOUNT, please email s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

82 lines
3.0KB

  1. from icosagon.databatch import DataBatcher, \
  2. BatchedData
  3. from icosagon.data import Data
  4. from icosagon.trainprep import prepare_training, \
  5. TrainValTest
  6. import torch
  7. def _some_data():
  8. data = Data()
  9. data.add_node_type('Foo', 100)
  10. data.add_node_type('Bar', 500)
  11. fam = data.add_relation_family('Foo-Bar', 0, 1, True)
  12. adj_mat = torch.rand(100, 500).round().to_sparse()
  13. fam.add_relation_type('Foo-Bar', adj_mat)
  14. return data
  15. def test_data_batcher_01():
  16. data = _some_data()
  17. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  18. batcher = DataBatcher(prep_d, 512)
  19. def test_data_batcher_02():
  20. data = _some_data()
  21. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  22. batcher = DataBatcher(prep_d, 512)
  23. for batch_d in batcher:
  24. pass
  25. def test_data_batcher_03():
  26. data = _some_data()
  27. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  28. batcher = DataBatcher(prep_d, 512)
  29. for batch_d in batcher:
  30. edges_list = []
  31. for fam in batch_d.relation_families:
  32. for rel in fam.relation_types:
  33. for edge_type in ['edges_pos', 'edges_neg',
  34. 'edges_back_pos', 'edges_back_neg']:
  35. for part_type in ['train', 'val', 'test']:
  36. edges = getattr(getattr(rel, edge_type), part_type)
  37. edges_list.append(edges)
  38. assert sum([ 1 for edges in edges_list if len(edges) > 0 ]) == 1
  39. def test_data_batcher_04():
  40. data = _some_data()
  41. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  42. batcher = DataBatcher(prep_d, 512)
  43. edges_list = []
  44. for batch_d in batcher:
  45. for fam in batch_d.relation_families:
  46. for rel in fam.relation_types:
  47. for edge_type in ['edges_pos', 'edges_neg',
  48. 'edges_back_pos', 'edges_back_neg']:
  49. for part_type in ['train', 'val', 'test']:
  50. edges = getattr(getattr(rel, edge_type), part_type)
  51. edges_list.append(edges)
  52. assert sum([ len(edges) for edges in edges_list ]) == \
  53. torch.sum(data.relation_families[0].relation_types[0].adjacency_matrix._values()) * 2
  54. def test_data_batcher_05():
  55. data = _some_data()
  56. prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
  57. batcher = DataBatcher(prep_d, 512)
  58. for batch_d in batcher:
  59. edges_list = []
  60. for fam in batch_d.relation_families:
  61. for rel in fam.relation_types:
  62. for edge_type in ['edges_pos', 'edges_neg',
  63. 'edges_back_pos', 'edges_back_neg']:
  64. for part_type in ['train', 'val', 'test']:
  65. edges = getattr(getattr(rel, edge_type), part_type)
  66. edges_list.append(edges)
  67. assert all([ len(edges) <= 512 for edges in edges_list ])
  68. assert not all([ len(edges) == 0 for edges in edges_list ])
  69. print(sum(map(len, edges_list)))