IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

114 wiersze
4.6KB

  1. from icosagon.data import Data
  2. from icosagon.bulkdec import BulkDecodeLayer
  3. from icosagon.input import OneHotInputLayer
  4. from icosagon.convlayer import DecagonLayer
  5. import torch
  6. def test_bulk_decode_layer_01():
  7. data = Data()
  8. data.add_node_type('Dummy', 100)
  9. fam = data.add_relation_family('Dummy-Dummy', 0, 0, False)
  10. fam.add_relation_type('Dummy Relation 1',
  11. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  12. in_layer = OneHotInputLayer(data)
  13. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  14. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  15. keep_prob=1., activation=lambda x: x)
  16. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  17. pred = seq(None)
  18. assert isinstance(pred, list)
  19. assert len(pred) == len(data.relation_families)
  20. assert isinstance(pred[0], torch.Tensor)
  21. assert len(pred[0].shape) == 3
  22. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  23. assert pred[0].shape[1] == data.node_types[0].count
  24. assert pred[0].shape[2] == data.node_types[0].count
  25. def test_bulk_decode_layer_02():
  26. data = Data()
  27. data.add_node_type('Foo', 100)
  28. data.add_node_type('Bar', 50)
  29. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  30. fam.add_relation_type('Foobar Relation 1',
  31. torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
  32. torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
  33. in_layer = OneHotInputLayer(data)
  34. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  35. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  36. keep_prob=1., activation=lambda x: x)
  37. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  38. pred = seq(None)
  39. assert isinstance(pred, list)
  40. assert len(pred) == len(data.relation_families)
  41. assert isinstance(pred[0], torch.Tensor)
  42. assert len(pred[0].shape) == 3
  43. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  44. assert pred[0].shape[1] == data.node_types[0].count
  45. assert pred[0].shape[2] == data.node_types[1].count
  46. def test_bulk_decode_layer_03():
  47. data = Data()
  48. data.add_node_type('Foo', 100)
  49. data.add_node_type('Bar', 50)
  50. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  51. fam.add_relation_type('Foobar Relation 1',
  52. torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
  53. torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
  54. fam.add_relation_type('Foobar Relation 2',
  55. torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
  56. torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
  57. in_layer = OneHotInputLayer(data)
  58. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  59. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  60. keep_prob=1., activation=lambda x: x)
  61. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  62. pred = seq(None)
  63. assert isinstance(pred, list)
  64. assert len(pred) == len(data.relation_families)
  65. assert isinstance(pred[0], torch.Tensor)
  66. assert len(pred[0].shape) == 3
  67. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  68. assert pred[0].shape[1] == data.node_types[0].count
  69. assert pred[0].shape[2] == data.node_types[1].count
  70. def test_bulk_decode_layer_03_big():
  71. data = Data()
  72. data.add_node_type('Foo', 2000)
  73. data.add_node_type('Bar', 2100)
  74. fam = data.add_relation_family('Foo-Bar', 0, 1, False)
  75. fam.add_relation_type('Foobar Relation 1',
  76. torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
  77. torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
  78. fam.add_relation_type('Foobar Relation 2',
  79. torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
  80. torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
  81. in_layer = OneHotInputLayer(data)
  82. d_layer = DecagonLayer(in_layer.output_dim, 32, data)
  83. dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
  84. keep_prob=1., activation=lambda x: x)
  85. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  86. pred = seq(None)
  87. assert isinstance(pred, list)
  88. assert len(pred) == len(data.relation_families)
  89. assert isinstance(pred[0], torch.Tensor)
  90. assert len(pred[0].shape) == 3
  91. assert len(pred[0]) == len(data.relation_families[0].relation_types)
  92. assert pred[0].shape[1] == data.node_types[0].count
  93. assert pred[0].shape[2] == data.node_types[1].count