IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

167 lines
6.6KB

  1. from icosagon.batch import PredictionsBatch, \
  2. FlatPredictions, \
  3. flatten_predictions, \
  4. BatchIndices, \
  5. gather_batch_indices
  6. from icosagon.declayer import Predictions, \
  7. RelationPredictions, \
  8. RelationFamilyPredictions
  9. from icosagon.trainprep import prepare_training, \
  10. TrainValTest
  11. from icosagon.data import Data
  12. import torch
  13. import pytest
  def test_flat_predictions_01():
      """FlatPredictions keeps the predictions, truth and part type it is given."""
      pred = FlatPredictions(torch.tensor([0, 1, 0, 1]),
                             torch.tensor([1, 0, 1, 0]), 'train')
      expected_predictions = torch.tensor([0, 1, 0, 1])
      expected_truth = torch.tensor([1, 0, 1, 0])
      # ``torch.all(a == b)`` broadcasts, so a wrongly-sized result (e.g. a
      # single element) could still pass; pin the shapes explicitly first.
      assert pred.predictions.shape == expected_predictions.shape
      assert torch.all(pred.predictions == expected_predictions)
      assert pred.truth.shape == expected_truth.shape
      assert torch.all(pred.truth == expected_truth)
      assert pred.part_type == 'train'
  def test_flatten_predictions_01():
      """flatten_predictions concatenates the 'train' part of every relation.

      As asserted below, the flat predictions are the values of the first
      TrainValTest followed by those of the second, and the truth is 1 for
      the first five entries and 0 for the last five.
      """
      # NOTE(review): the four positional TrainValTest arguments are presumably
      # positive then negative edge predictions, with the last two left empty
      # here -- confirm against RelationPredictions.
      rel_pred = RelationPredictions(
          TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                       torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                       torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
      )
      fam_pred = RelationFamilyPredictions([rel_pred])
      pred = Predictions([fam_pred])
      pred_flat = flatten_predictions(pred, part_type='train')
      expected_predictions = torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1],
                                          dtype=torch.float32)
      expected_truth = torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                                    dtype=torch.float32)
      # Check shapes first: ``torch.all(a == b)`` broadcasts and could pass
      # on a wrongly-sized result.
      assert pred_flat.predictions.shape == expected_predictions.shape
      assert torch.all(pred_flat.predictions == expected_predictions)
      assert pred_flat.truth.shape == expected_truth.shape
      assert torch.all(pred_flat.truth == expected_truth)
      assert pred_flat.part_type == 'train'
  def test_flatten_predictions_02():
      """Flattening the 'val' part of train-only predictions yields nothing."""
      def split(train_part):
          # Only the training part is populated; val and test stay empty.
          return TrainValTest(train_part, torch.zeros(0), torch.zeros(0))

      def train_values():
          return torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)

      relation = RelationPredictions(
          split(train_values()),
          split(train_values()),
          split(torch.zeros(0)),
          split(torch.zeros(0)),
      )
      everything = Predictions([RelationFamilyPredictions([relation])])
      flat = flatten_predictions(everything, part_type='val')
      assert len(flat.predictions) == 0
      assert len(flat.truth) == 0
      assert flat.part_type == 'val'
  def test_flatten_predictions_03():
      """Flattening the 'test' part of train-only predictions yields nothing."""
      # Two TrainValTests with five training values each, then two fully
      # empty ones; every object (and tensor) is freshly constructed.
      filled = [
          TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                       torch.zeros(0), torch.zeros(0))
          for _ in range(2)
      ]
      empty = [
          TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
          for _ in range(2)
      ]
      relation = RelationPredictions(*filled, *empty)
      family = RelationFamilyPredictions([relation])
      flat = flatten_predictions(Predictions([family]), part_type='test')
      assert len(flat.predictions) == 0
      assert len(flat.truth) == 0
      assert flat.part_type == 'test'
  def test_flatten_predictions_04():
      """flatten_predictions rejects a non-Predictions input and a bad part type."""
      rel_pred = RelationPredictions(
          TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                       torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                       torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
      )
      fam_pred = RelationFamilyPredictions([rel_pred])
      pred = Predictions([fam_pred])
      # A plain int is not a Predictions object.
      with pytest.raises(TypeError):
          flatten_predictions(1, part_type='test')
      # A valid Predictions object with an unrecognised part type.
      with pytest.raises(ValueError):
          flatten_predictions(pred, part_type='x')
  def test_batch_indices_01():
      """BatchIndices keeps the index tensor and part type it is given."""
      indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train')
      expected = torch.tensor([0, 1, 2, 3, 4])
      # Pin the shape; ``torch.all(a == b)`` alone would broadcast.
      assert indices.indices.shape == expected.shape
      assert torch.all(indices.indices == expected)
      assert indices.part_type == 'train'
  def test_gather_batch_indices_01():
      """gather_batch_indices picks predictions and truth at the given indices."""
      rel_pred = RelationPredictions(
          TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                       torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                       torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
      )
      fam_pred = RelationFamilyPredictions([rel_pred])
      pred = Predictions([fam_pred])
      pred_flat = flatten_predictions(pred, part_type='train')
      # Flattened layout (the one test_flatten_predictions_01 checks):
      #   predictions [1, 0, 1, 0, 1, 1, 0, 1, 0, 1]
      #   truth       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
      # Indices 0, 2, 4, 5, 7, 9 only hit predictions equal to 1 and cover
      # three entries with truth 1 followed by three with truth 0.
      indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train')
      # ``inputs`` rather than ``input`` so the builtin is not shadowed.
      inputs, target = gather_batch_indices(pred_flat, indices)
      expected_inputs = torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32)
      expected_target = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32)
      # Shape checks guard against ``==`` broadcasting a wrongly-sized result.
      assert inputs.shape == expected_inputs.shape
      assert torch.all(inputs == expected_inputs)
      assert target.shape == expected_target.shape
      assert torch.all(target == expected_target)
  def test_predictions_batch_01():
      """Iterating PredictionsBatch with batch_size=1 visits every train entry.

      A 5-node relation with 5 edges is prepared with nothing in val/test.
      Each single-element batch is used to gather from flattened predictions,
      and the collected (prediction, truth) pairs are compared to the full
      expected sequence.
      """
      d = Data()
      d.add_node_type('Dummy', 5)
      fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
      fam.add_relation_type('Dummy Rel', torch.tensor([
          [0, 1, 0, 0, 0],
          [1, 0, 0, 0, 0],
          [0, 0, 0, 1, 0],
          [0, 0, 0, 0, 1],
          [0, 1, 0, 0, 0],
      ], dtype=torch.float32))
      # Presumably train/val/test ratios of 1/0/0; the asserts below confirm
      # nothing lands in the val or test parts.
      prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
      assert len(prep_d.relation_families) == 1
      assert len(prep_d.relation_families[0].relation_types) == 1
      rel_type = prep_d.relation_families[0].relation_types[0]
      assert len(rel_type.edges_pos.train) == 5
      assert len(rel_type.edges_neg.train) == 5
      assert len(rel_type.edges_pos.val) == 0
      assert len(rel_type.edges_pos.test) == 0
      rel_pred = RelationPredictions(
          TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                       torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                       torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
          TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
      )
      fam_pred = RelationFamilyPredictions([rel_pred])
      pred = Predictions([fam_pred])
      pred_flat = flatten_predictions(pred, part_type='train')
      batch = PredictionsBatch(prep_d, part_type='train', batch_size=1)
      pairs = []
      for indices in batch:
          # ``inputs`` rather than ``input`` so the builtin is not shadowed.
          inputs, target = gather_batch_indices(pred_flat, indices)
          assert len(inputs) == 1
          assert len(target) == 1
          # .item() turns the 0-d tensors into plain numbers, so the list
          # comparison below does not rely on tensor truthiness.
          pairs.append((inputs[0].item(), target[0].item()))
      # NOTE(review): the expected order assumes PredictionsBatch yields
      # indices sequentially (no shuffling) -- confirm if that changes.
      assert len(pairs) == 10
      assert pairs == [(1, 1), (0, 1), (1, 1), (0, 1), (1, 1),
                       (1, 0), (0, 0), (1, 0), (0, 0), (1, 0)]