IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

67 lines
1.7KB

  1. from triacontagon.loop import _merge_pos_neg_batches
  2. from triacontagon.model import TrainingBatch
  3. import torch
  4. import pytest
  5. def test_merge_pos_neg_batches_01():
  6. b_1 = TrainingBatch(0, 0, 0, torch.tensor([
  7. [0, 1],
  8. [2, 3],
  9. [4, 5],
  10. [5, 6]
  11. ]), torch.ones(4))
  12. b_2 = TrainingBatch(0, 0, 0, torch.tensor([
  13. [1, 6],
  14. [3, 5],
  15. [5, 2],
  16. [4, 1]
  17. ]), torch.zeros(4))
  18. b = _merge_pos_neg_batches(b_1, b_2)
  19. assert b.vertex_type_row == 0
  20. assert b.vertex_type_column == 0
  21. assert b.relation_type_index == 0
  22. assert torch.all(b.edges == torch.tensor([
  23. [0, 1],
  24. [2, 3],
  25. [4, 5],
  26. [5, 6],
  27. [1, 6],
  28. [3, 5],
  29. [5, 2],
  30. [4, 1]
  31. ]))
  32. assert torch.all(b.target_values == \
  33. torch.cat([ torch.ones(4), torch.zeros(4) ]))
  34. def test_merge_pos_neg_batches_02():
  35. b_1 = TrainingBatch(0, 1, 0, torch.tensor([
  36. [0, 1],
  37. [2, 3],
  38. [4, 5],
  39. [5, 6]
  40. ]), torch.ones(4))
  41. b_2 = TrainingBatch(0, 0, 0, torch.tensor([
  42. [1, 6],
  43. [3, 5],
  44. [5, 2],
  45. [4, 1]
  46. ]), torch.zeros(4))
  47. print(b_1)
  48. with pytest.raises(AssertionError):
  49. _ = _merge_pos_neg_batches(b_1, b_2)
  50. b_1.vertex_type_row, b_1.vertex_type_column = \
  51. b_1.vertex_type_column, b_1.vertex_type_row
  52. print(b_1)
  53. with pytest.raises(AssertionError):
  54. _ = _merge_pos_neg_batches(b_1, b_2)
  55. b_1.vertex_type_row, b_1.relation_type_index = \
  56. b_1.relation_type_index, b_1.vertex_type_row
  57. print(b_1)
  58. with pytest.raises(AssertionError):
  59. _ = _merge_pos_neg_batches(b_1, b_2)