If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

135 lines
3.6KB

  1. from triacontagon.loop import _merge_pos_neg_batches, \
  2. TrainLoop
  3. from triacontagon.model import TrainingBatch, \
  4. Model
  5. from triacontagon.data import Data
  6. from triacontagon.decode import dedicom_decoder
  7. from triacontagon.util import common_one_hot_encoding
  8. from triacontagon.split import split_data
  9. import torch
  10. import pytest
  11. def test_merge_pos_neg_batches_01():
  12. b_1 = TrainingBatch(0, 0, 0, torch.tensor([
  13. [0, 1],
  14. [2, 3],
  15. [4, 5],
  16. [5, 6]
  17. ]), torch.ones(4))
  18. b_2 = TrainingBatch(0, 0, 0, torch.tensor([
  19. [1, 6],
  20. [3, 5],
  21. [5, 2],
  22. [4, 1]
  23. ]), torch.zeros(4))
  24. b = _merge_pos_neg_batches(b_1, b_2)
  25. assert b.vertex_type_row == 0
  26. assert b.vertex_type_column == 0
  27. assert b.relation_type_index == 0
  28. assert torch.all(b.edges == torch.tensor([
  29. [0, 1],
  30. [2, 3],
  31. [4, 5],
  32. [5, 6],
  33. [1, 6],
  34. [3, 5],
  35. [5, 2],
  36. [4, 1]
  37. ]))
  38. assert torch.all(b.target_values == \
  39. torch.cat([ torch.ones(4), torch.zeros(4) ]))
  40. def test_merge_pos_neg_batches_02():
  41. b_1 = TrainingBatch(0, 1, 0, torch.tensor([
  42. [0, 1],
  43. [2, 3],
  44. [4, 5],
  45. [5, 6]
  46. ]), torch.ones(4))
  47. b_2 = TrainingBatch(0, 0, 0, torch.tensor([
  48. [1, 6],
  49. [3, 5],
  50. [5, 2],
  51. [4, 1]
  52. ]), torch.zeros(4))
  53. print(b_1)
  54. with pytest.raises(AssertionError):
  55. _ = _merge_pos_neg_batches(b_1, b_2)
  56. b_1.vertex_type_row, b_1.vertex_type_column = \
  57. b_1.vertex_type_column, b_1.vertex_type_row
  58. print(b_1)
  59. with pytest.raises(AssertionError):
  60. _ = _merge_pos_neg_batches(b_1, b_2)
  61. b_1.vertex_type_row, b_1.relation_type_index = \
  62. b_1.relation_type_index, b_1.vertex_type_row
  63. print(b_1)
  64. with pytest.raises(AssertionError):
  65. _ = _merge_pos_neg_batches(b_1, b_2)
  66. def test_train_loop_01():
  67. data = Data()
  68. data.add_vertex_type('Foo', 5)
  69. data.add_vertex_type('Bar', 4)
  70. foo_foo = torch.tensor([
  71. [0, 0, 1, 0, 0],
  72. [0, 0, 0, 1, 0],
  73. [1, 0, 0, 0, 0],
  74. [0, 1, 0, 0, 0],
  75. [0, 0, 0, 0, 0]
  76. ], dtype=torch.float32)
  77. foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
  78. foo_bar = torch.tensor([
  79. [0, 0, 1, 0],
  80. [0, 0, 0, 1],
  81. [0, 1, 0, 0],
  82. [1, 0, 0, 0],
  83. [0, 0, 0, 1]
  84. ], dtype=torch.float32)
  85. bar_foo = foo_bar.transpose(0, 1)
  86. bar_bar = torch.tensor([
  87. [0, 1, 0, 0],
  88. [1, 0, 0, 0],
  89. [0, 0, 0, 1],
  90. [0, 0, 1, 0],
  91. ], dtype=torch.float32)
  92. bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
  93. data.add_edge_type('Foo-Foo', 0, 0, [
  94. foo_foo.to_sparse().coalesce()
  95. ], dedicom_decoder)
  96. data.add_edge_type('Foo-Bar', 0, 1, [
  97. foo_bar.to_sparse().coalesce()
  98. ], dedicom_decoder)
  99. data.add_edge_type('Bar-Foo', 1, 0, [
  100. bar_foo.to_sparse().coalesce()
  101. ], dedicom_decoder)
  102. data.add_edge_type('Bar-Bar', 1, 1, [
  103. bar_bar.to_sparse().coalesce()
  104. ], dedicom_decoder)
  105. initial_repr = common_one_hot_encoding([5, 4])
  106. model = Model(data, [9, 3, 6],
  107. keep_prob=1.0,
  108. conv_activation=torch.sigmoid,
  109. dec_activation=torch.sigmoid)
  110. train_data, val_data, test_data = split_data(data, (.5, .5, .0) )
  111. print('val_data:', val_data)
  112. print('val_data.vertex_types:', val_data.vertex_types)
  113. loop = TrainLoop(model, val_data, test_data, initial_repr,
  114. max_epochs=1, batch_size=1)
  115. _ = loop.run()