IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or opening pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

119 lines
4.3KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from icosagon.input import OneHotInputLayer
  6. from icosagon.convlayer import DecagonLayer
  7. from icosagon.declayer import DecodeLayer, \
  8. Predictions, \
  9. RelationFamilyPredictions, \
  10. RelationPredictions
  11. from icosagon.decode import DEDICOMDecoder
  12. from icosagon.data import Data
  13. from icosagon.trainprep import prepare_training, \
  14. TrainValTest
  15. import torch
  16. def test_decode_layer_01():
  17. d = Data()
  18. d.add_node_type('Dummy', 100)
  19. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  20. fam.add_relation_type('Dummy Relation 1',
  21. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  22. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  23. in_layer = OneHotInputLayer(d)
  24. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  25. seq = torch.nn.Sequential(in_layer, d_layer)
  26. last_layer_repr = seq(None)
  27. dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
  28. activation=lambda x: x)
  29. pred = dec(last_layer_repr)
  30. assert isinstance(pred, Predictions)
  31. assert isinstance(pred.relation_families, list)
  32. assert len(pred.relation_families) == 1
  33. assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
  34. assert isinstance(pred.relation_families[0].relation_types, list)
  35. assert len(pred.relation_families[0].relation_types) == 1
  36. assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions)
  37. tmp = pred.relation_families[0].relation_types[0]
  38. assert isinstance(tmp.edges_pos, TrainValTest)
  39. assert isinstance(tmp.edges_neg, TrainValTest)
  40. assert isinstance(tmp.edges_back_pos, TrainValTest)
  41. assert isinstance(tmp.edges_back_neg, TrainValTest)
  42. def test_decode_layer_02():
  43. d = Data()
  44. d.add_node_type('Dummy', 100)
  45. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  46. fam.add_relation_type('Dummy Relation 1',
  47. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  48. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  49. in_layer = OneHotInputLayer(d)
  50. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  51. dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
  52. keep_prob=1., activation=lambda x: x)
  53. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  54. pred = seq(None)
  55. assert isinstance(pred, Predictions)
  56. assert len(pred.relation_families) == 1
  57. assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
  58. assert isinstance(pred.relation_families[0].relation_types, list)
  59. assert len(pred.relation_families[0].relation_types) == 1
  60. def test_decode_layer_03():
  61. d = Data()
  62. d.add_node_type('Dummy 1', 100)
  63. d.add_node_type('Dummy 2', 100)
  64. fam = d.add_relation_family('Dummy 1-Dummy 2', 0, 1, True)
  65. fam.add_relation_type('Dummy Relation 1',
  66. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  67. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  68. in_layer = OneHotInputLayer(d)
  69. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  70. dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
  71. keep_prob=1., activation=lambda x: x)
  72. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  73. pred = seq(None)
  74. assert isinstance(pred, Predictions)
  75. assert len(pred.relation_families) == 1
  76. assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
  77. assert isinstance(pred.relation_families[0].relation_types, list)
  78. assert len(pred.relation_families[0].relation_types) == 1
  79. assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions)
  80. def test_decode_layer_04():
  81. d = Data()
  82. d.add_node_type('Dummy', 100)
  83. assert len(d.relation_families) == 0
  84. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  85. in_layer = OneHotInputLayer(d)
  86. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  87. dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
  88. keep_prob=1., activation=lambda x: x)
  89. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  90. pred = seq(None)
  91. assert isinstance(pred, Predictions)
  92. assert len(pred.relation_families) == 0