IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

188 lines
7.4KB

  1. from icosagon.batch import PredictionsBatch, \
  2. FlatPredictions, \
  3. flatten_predictions, \
  4. BatchIndices, \
  5. gather_batch_indices
  6. from icosagon.declayer import Predictions, \
  7. RelationPredictions, \
  8. RelationFamilyPredictions
  9. from icosagon.trainprep import prepare_training, \
  10. TrainValTest
  11. from icosagon.data import Data
  12. import torch
  13. import pytest
  14. def test_flat_predictions_01():
  15. pred = FlatPredictions(torch.tensor([0, 1, 0, 1]),
  16. torch.tensor([1, 0, 1, 0]), 'train')
  17. assert torch.all(pred.predictions == torch.tensor([0, 1, 0, 1]))
  18. assert torch.all(pred.truth == torch.tensor([1, 0, 1, 0]))
  19. assert pred.part_type == 'train'
  20. def test_flatten_predictions_01():
  21. rel_pred = RelationPredictions(
  22. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  23. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  24. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  25. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  26. )
  27. fam_pred = RelationFamilyPredictions([ rel_pred ])
  28. pred = Predictions([ fam_pred ])
  29. pred_flat = flatten_predictions(pred, part_type='train')
  30. assert torch.all(pred_flat.predictions == \
  31. torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1], dtype=torch.float32))
  32. assert torch.all(pred_flat.truth == \
  33. torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=torch.float32))
  34. assert pred_flat.part_type == 'train'
  35. def test_flatten_predictions_02():
  36. rel_pred = RelationPredictions(
  37. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  38. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  39. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  40. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  41. )
  42. fam_pred = RelationFamilyPredictions([ rel_pred ])
  43. pred = Predictions([ fam_pred ])
  44. pred_flat = flatten_predictions(pred, part_type='val')
  45. assert len(pred_flat.predictions) == 0
  46. assert len(pred_flat.truth) == 0
  47. assert pred_flat.part_type == 'val'
  48. def test_flatten_predictions_03():
  49. rel_pred = RelationPredictions(
  50. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  51. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  52. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  53. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  54. )
  55. fam_pred = RelationFamilyPredictions([ rel_pred ])
  56. pred = Predictions([ fam_pred ])
  57. pred_flat = flatten_predictions(pred, part_type='test')
  58. assert len(pred_flat.predictions) == 0
  59. assert len(pred_flat.truth) == 0
  60. assert pred_flat.part_type == 'test'
  61. def test_flatten_predictions_04():
  62. rel_pred = RelationPredictions(
  63. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  64. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  65. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  66. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  67. )
  68. fam_pred = RelationFamilyPredictions([ rel_pred ])
  69. pred = Predictions([ fam_pred ])
  70. with pytest.raises(TypeError):
  71. pred_flat = flatten_predictions(1, part_type='test')
  72. with pytest.raises(ValueError):
  73. pred_flat = flatten_predictions(pred, part_type='x')
  74. def test_flatten_predictions_05():
  75. x = torch.rand(5000)
  76. y = torch.cat([ x, x ])
  77. z = torch.cat([ torch.ones(5000), torch.zeros(5000) ])
  78. rel_pred = RelationPredictions(
  79. TrainValTest(x, torch.zeros(0), torch.zeros(0)),
  80. TrainValTest(x, torch.zeros(0), torch.zeros(0)),
  81. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  82. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  83. )
  84. fam_pred = RelationFamilyPredictions([ rel_pred ])
  85. pred = Predictions([ fam_pred ])
  86. for _ in range(10):
  87. pred_flat = flatten_predictions(pred, part_type='train')
  88. assert torch.all(pred_flat.predictions == y)
  89. assert torch.all(pred_flat.truth == z)
  90. assert pred_flat.part_type == 'train'
  91. def test_batch_indices_01():
  92. indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train')
  93. assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4]))
  94. assert indices.part_type == 'train'
  95. def test_gather_batch_indices_01():
  96. rel_pred = RelationPredictions(
  97. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  98. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  99. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  100. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  101. )
  102. fam_pred = RelationFamilyPredictions([ rel_pred ])
  103. pred = Predictions([ fam_pred ])
  104. pred_flat = flatten_predictions(pred, part_type='train')
  105. indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train')
  106. (input, target) = gather_batch_indices(pred_flat, indices)
  107. assert torch.all(input == \
  108. torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32))
  109. assert torch.all(target == \
  110. torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32))
  111. def test_predictions_batch_01():
  112. d = Data()
  113. d.add_node_type('Dummy', 5)
  114. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  115. fam.add_relation_type('Dummy Rel', torch.tensor([
  116. [0, 1, 0, 0, 0],
  117. [1, 0, 0, 0, 0],
  118. [0, 0, 0, 1, 0],
  119. [0, 0, 0, 0, 1],
  120. [0, 1, 0, 0, 0]
  121. ], dtype=torch.float32))
  122. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  123. assert len(prep_d.relation_families) == 1
  124. assert len(prep_d.relation_families[0].relation_types) == 1
  125. assert len(prep_d.relation_families[0].relation_types[0].edges_pos.train) == 5
  126. assert len(prep_d.relation_families[0].relation_types[0].edges_neg.train) == 5
  127. assert len(prep_d.relation_families[0].relation_types[0].edges_pos.val) == 0
  128. assert len(prep_d.relation_families[0].relation_types[0].edges_pos.test) == 0
  129. rel_pred = RelationPredictions(
  130. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  131. TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
  132. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
  133. TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
  134. )
  135. fam_pred = RelationFamilyPredictions([ rel_pred ])
  136. pred = Predictions([ fam_pred ])
  137. pred_flat = flatten_predictions(pred, part_type='train')
  138. batch = PredictionsBatch(prep_d, part_type='train', batch_size=1)
  139. count = 0
  140. lst = []
  141. for indices in batch:
  142. (input, target) = gather_batch_indices(pred_flat, indices)
  143. assert len(input) == 1
  144. assert len(target) == 1
  145. lst.append((input[0], target[0]))
  146. count += 1
  147. assert lst == [ (1, 1), (0, 1), (1, 1), (0, 1), (1, 1),
  148. (1, 0), (0, 0), (1, 0), (0, 0), (1, 0) ]
  149. assert count == 10