If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

105 řádků
3.6KB

  1. from icosagon.data import Data, \
  2. _equal
  3. from icosagon.model import Model
  4. from icosagon.trainprep import PreparedData, \
  5. PreparedRelationFamily, \
  6. PreparedRelationType, \
  7. TrainValTest, \
  8. norm_adj_mat_one_node_type
  9. import torch
  10. from icosagon.input import OneHotInputLayer
  11. from icosagon.convlayer import DecagonLayer
  12. from icosagon.declayer import DecodeLayer
  13. def _is_identity_function(f):
  14. for x in range(-100, 101):
  15. if f(x) != x:
  16. return False
  17. return True
  def test_model_01():
      """Model(d) with no extra arguments exposes these default settings."""
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      adjacency = torch.rand(10, 10).round()
      family.add_relation_type('Dummy Rel', adjacency)
      model = Model(data)

      # Inputs and split/dropout settings stored on the model.
      assert model.data == data
      assert model.layer_dimensions == [32, 64]
      split = model.ratios
      assert (split.train, split.val, split.test) == (.8, .1, .1)
      assert model.keep_prob == 1.

      # Activations: relation and decoder activations act as the identity
      # on the sampled integers; the layer activation is ReLU.
      assert _is_identity_function(model.rel_activation)
      assert model.layer_activation == torch.nn.functional.relu
      assert _is_identity_function(model.dec_activation)

      # Optimisation settings.
      assert model.lr == 0.001
      assert model.loss == torch.nn.functional.binary_cross_entropy_with_logits
      assert model.batch_size == 100

      # Objects built by the constructor.
      assert isinstance(model.prep_d, PreparedData)
      assert isinstance(model.seq, torch.nn.Sequential)
      assert isinstance(model.opt, torch.optim.Optimizer)
  def test_model_02():
      """One sparse relation type with a (1, 0, 0) split: check prepared data and layers."""
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      adjacency = torch.rand(10, 10).round().to_sparse()
      family.add_relation_type('Dummy Rel', adjacency)
      model = Model(data, ratios=TrainValTest(1., 0., 0.))

      # Prepared data: exactly one family holding exactly one relation type.
      prepared = model.prep_d
      assert isinstance(prepared, PreparedData)
      families = prepared.relation_families
      assert isinstance(families, list)
      assert len(families) == 1
      assert isinstance(families[0], PreparedRelationFamily)
      rel_types = families[0].relation_types
      assert len(rel_types) == 1
      rel_type = rel_types[0]
      assert isinstance(rel_type, PreparedRelationType)

      # No backward matrix; forward matrix is the normalized input adjacency.
      assert rel_type.adjacency_matrix_backward is None
      actual = rel_type.adjacency_matrix
      expected = norm_adj_mat_one_node_type(adjacency)
      assert torch.all(_equal(actual, expected))

      # Layer stack: input layer, two convolution layers, then the decoder.
      layer_types = (OneHotInputLayer, DecagonLayer, DecagonLayer, DecodeLayer)
      for position, layer_cls in enumerate(layer_types):
          assert isinstance(model.seq[position], layer_cls)
      assert len(model.seq) == 4
  def test_model_03():
      """Single relation type: optimizer state and per-layer parameter counts.

      Checks that the optimizer's ``state_dict()`` is a dict, and that
      ``m.seq[0]``, ``m.seq[1]`` and ``m.seq[2]`` each own one parameter while
      ``m.seq[3]`` owns two.
      """
      d = Data()
      d.add_node_type('Dummy', 10)
      fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
      mat = torch.rand(10, 10).round().to_sparse()
      fam.add_relation_type('Dummy Rel', mat)
      m = Model(d, ratios=TrainValTest(1., 0., 0.))

      state_dict = m.opt.state_dict()
      assert isinstance(state_dict, dict)

      # Expected number of parameters owned by seq[0] .. seq[3]; the message
      # names the failing layer so a regression is easy to locate.
      expected_counts = (1, 1, 1, 2)
      for i, expected in enumerate(expected_counts):
          actual = len(list(m.seq[i].parameters()))
          assert actual == expected, \
              f'seq[{i}] has {actual} parameters, expected {expected}'
  def test_model_04():
      """Two relation types in one family: parameter counts of each ``seq`` layer."""
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      adjacency = torch.rand(10, 10).round().to_sparse()
      family.add_relation_type('Dummy Rel 1', adjacency)
      family.add_relation_type('Dummy Rel 2', adjacency.clone())
      model = Model(data, ratios=TrainValTest(1., 0., 0.))

      # Expected number of parameters owned by seq[0] .. seq[3].
      for idx, n_params in enumerate((1, 2, 2, 3)):
          assert len(list(model.seq[idx].parameters())) == n_params