If you would like an account, please email s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

116 líneas
4.1KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from icosagon.input import OneHotInputLayer
  6. from icosagon.convlayer import DecagonLayer
  7. from icosagon.declayer import DecodeLayer, \
  8. Predictions, \
  9. RelationFamilyPredictions, \
  10. RelationPredictions
  11. from icosagon.decode import DEDICOMDecoder
  12. from icosagon.data import Data
  13. from icosagon.trainprep import prepare_training, \
  14. TrainValTest
  15. import torch
  16. def test_decode_layer_01():
  17. d = Data()
  18. d.add_node_type('Dummy', 100)
  19. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  20. fam.add_relation_type('Dummy Relation 1',
  21. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  22. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  23. in_layer = OneHotInputLayer(d)
  24. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  25. seq = torch.nn.Sequential(in_layer, d_layer)
  26. last_layer_repr = seq(None)
  27. dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
  28. activation=lambda x: x)
  29. pred = dec(last_layer_repr)
  30. assert isinstance(pred, Predictions)
  31. assert isinstance(pred.relation_families, list)
  32. assert len(pred.relation_families) == 1
  33. assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
  34. assert isinstance(pred.relation_families[0].relation_types, list)
  35. assert len(pred.relation_families[0].relation_types) == 1
  36. assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions)
  37. tmp = pred.relation_families[0].relation_types[0]
  38. assert isinstance(tmp.edges_pos, TrainValTest)
  39. assert isinstance(tmp.edges_neg, TrainValTest)
  40. assert isinstance(tmp.edges_back_pos, TrainValTest)
  41. assert isinstance(tmp.edges_back_neg, TrainValTest)
  42. def test_decode_layer_02():
  43. d = Data()
  44. d.add_node_type('Dummy', 100)
  45. d.add_relation_type('Dummy Relation 1', 0, 0,
  46. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  47. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  48. in_layer = OneHotInputLayer(d)
  49. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  50. dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
  51. decoder_class=DEDICOMDecoder, activation=lambda x: x)
  52. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  53. pred_adj_matrices = seq(None)
  54. assert isinstance(pred_adj_matrices, dict)
  55. assert len(pred_adj_matrices) == 1
  56. assert isinstance(pred_adj_matrices[0, 0], list)
  57. assert len(pred_adj_matrices[0, 0]) == 1
  58. def test_decode_layer_03():
  59. d = Data()
  60. d.add_node_type('Dummy 1', 100)
  61. d.add_node_type('Dummy 2', 100)
  62. d.add_relation_type('Dummy Relation 1', 0, 1,
  63. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  64. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  65. in_layer = OneHotInputLayer(d)
  66. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  67. dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
  68. decoder_class={(0, 1): DEDICOMDecoder}, activation=lambda x: x)
  69. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  70. pred_adj_matrices = seq(None)
  71. assert isinstance(pred_adj_matrices, dict)
  72. assert len(pred_adj_matrices) == 2
  73. assert isinstance(pred_adj_matrices[0, 1], list)
  74. assert isinstance(pred_adj_matrices[1, 0], list)
  75. assert len(pred_adj_matrices[0, 1]) == 1
  76. assert len(pred_adj_matrices[1, 0]) == 1
  77. def test_decode_layer_04():
  78. d = Data()
  79. d.add_node_type('Dummy', 100)
  80. assert len(d.relation_types[0, 0]) == 0
  81. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  82. in_layer = OneHotInputLayer(d)
  83. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  84. dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
  85. decoder_class=DEDICOMDecoder, activation=lambda x: x)
  86. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  87. pred_adj_matrices = seq(None)
  88. assert isinstance(pred_adj_matrices, dict)
  89. assert len(pred_adj_matrices) == 0