IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
преди 4 години
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. from icosagon.data import Data, \
  2. _equal
  3. from icosagon.model import Model
  4. from icosagon.trainprep import PreparedData, \
  5. PreparedRelationFamily, \
  6. PreparedRelationType, \
  7. TrainValTest, \
  8. norm_adj_mat_one_node_type, \
  9. prepare_training
  10. import torch
  11. from icosagon.input import OneHotInputLayer
  12. from icosagon.convlayer import DecagonLayer
  13. from icosagon.declayer import DecodeLayer
  14. import pytest
  15. def _is_identity_function(f):
  16. for x in range(-100, 101):
  17. if f(x) != x:
  18. return False
  19. return True
  20. def test_model_01():
  21. d = Data()
  22. d.add_node_type('Dummy', 10)
  23. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  24. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  25. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  26. m = Model(prep_d)
  27. assert m.prep_d == prep_d
  28. assert m.layer_dimensions == [32, 64]
  29. assert m.keep_prob == 1.
  30. assert _is_identity_function(m.rel_activation)
  31. assert m.layer_activation == torch.nn.functional.relu
  32. assert _is_identity_function(m.dec_activation)
  33. assert isinstance(m.seq, torch.nn.Sequential)
  34. def test_model_02():
  35. d = Data()
  36. d.add_node_type('Dummy', 10)
  37. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  38. mat = torch.rand(10, 10).round().to_sparse()
  39. fam.add_relation_type('Dummy Rel', mat)
  40. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  41. m = Model(prep_d)
  42. assert isinstance(m.prep_d, PreparedData)
  43. assert isinstance(m.prep_d.relation_families, list)
  44. assert len(m.prep_d.relation_families) == 1
  45. assert isinstance(m.prep_d.relation_families[0], PreparedRelationFamily)
  46. assert len(m.prep_d.relation_families[0].relation_types) == 1
  47. assert isinstance(m.prep_d.relation_families[0].relation_types[0], PreparedRelationType)
  48. assert m.prep_d.relation_families[0].relation_types[0].adjacency_matrix_backward is None
  49. assert torch.all(_equal(m.prep_d.relation_families[0].relation_types[0].adjacency_matrix,
  50. norm_adj_mat_one_node_type(mat)))
  51. assert isinstance(m.seq[0], OneHotInputLayer)
  52. assert isinstance(m.seq[1], DecagonLayer)
  53. assert isinstance(m.seq[2], DecagonLayer)
  54. assert isinstance(m.seq[3], DecodeLayer)
  55. assert len(m.seq) == 4
  56. def test_model_03():
  57. d = Data()
  58. d.add_node_type('Dummy', 10)
  59. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  60. mat = torch.rand(10, 10).round().to_sparse()
  61. fam.add_relation_type('Dummy Rel', mat)
  62. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  63. m = Model(prep_d)
  64. assert len(list(m.seq[0].parameters())) == 1
  65. assert len(list(m.seq[1].parameters())) == 1
  66. assert len(list(m.seq[2].parameters())) == 1
  67. assert len(list(m.seq[3].parameters())) == 2
  68. def test_model_04():
  69. d = Data()
  70. d.add_node_type('Dummy', 10)
  71. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  72. mat = torch.rand(10, 10).round().to_sparse()
  73. fam.add_relation_type('Dummy Rel 1', mat)
  74. fam.add_relation_type('Dummy Rel 2', mat.clone())
  75. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  76. m = Model(prep_d)
  77. assert len(list(m.seq[0].parameters())) == 1
  78. assert len(list(m.seq[1].parameters())) == 2
  79. assert len(list(m.seq[2].parameters())) == 2
  80. assert len(list(m.seq[3].parameters())) == 3
  81. def test_model_05():
  82. d = Data()
  83. d.add_node_type('Dummy', 10)
  84. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  85. mat = torch.rand(10, 10).round().to_sparse()
  86. fam.add_relation_type('Dummy Rel 1', mat)
  87. fam.add_relation_type('Dummy Rel 2', mat.clone())
  88. fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
  89. mat = torch.rand(10, 10).round().to_sparse()
  90. fam.add_relation_type('Dummy Rel 2-1', mat)
  91. fam.add_relation_type('Dummy Rel 2-2', mat.clone())
  92. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  93. m = Model(prep_d)
  94. assert len(list(m.seq[0].parameters())) == 1
  95. assert len(list(m.seq[1].parameters())) == 4
  96. assert len(list(m.seq[2].parameters())) == 4
  97. assert len(list(m.seq[3].parameters())) == 6
  98. def test_model_06():
  99. d = Data()
  100. d.add_node_type('Dummy', 10)
  101. d.add_node_type('Foobar', 20)
  102. fam = d.add_relation_family('Dummy-Foobar', 0, 1, True)
  103. mat = torch.rand(10, 20).round().to_sparse()
  104. fam.add_relation_type('Dummy Rel 1', mat)
  105. fam.add_relation_type('Dummy Rel 2', mat.clone())
  106. fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
  107. mat = torch.rand(10, 10).round().to_sparse()
  108. fam.add_relation_type('Dummy Rel 2-1', mat)
  109. fam.add_relation_type('Dummy Rel 2-2', mat.clone())
  110. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  111. m = Model(prep_d)
  112. assert len(list(m.seq[0].parameters())) == 2
  113. assert len(list(m.seq[1].parameters())) == 6
  114. assert len(list(m.seq[2].parameters())) == 6
  115. assert len(list(m.seq[3].parameters())) == 6
  116. def test_model_07():
  117. d = Data()
  118. d.add_node_type('Dummy', 10)
  119. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  120. fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
  121. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  122. with pytest.raises(TypeError):
  123. m = Model(1)
  124. with pytest.raises(TypeError):
  125. m = Model(prep_d, layer_dimensions=1)
  126. with pytest.raises(TypeError):
  127. m = Model(prep_d, ratios=1)
  128. with pytest.raises(ValueError):
  129. m = Model(prep_d, keep_prob='x')
  130. with pytest.raises(TypeError):
  131. m = Model(prep_d, rel_activation='x')
  132. with pytest.raises(TypeError):
  133. m = Model(prep_d, layer_activation='x')
  134. with pytest.raises(TypeError):
  135. m = Model(prep_d, dec_activation='x')
  136. def test_model_08():
  137. d = Data()
  138. d.add_node_type('Dummy', 10)
  139. d.add_node_type('Foobar', 20)
  140. fam = d.add_relation_family('Dummy-Foobar', 0, 1, True)
  141. mat = torch.rand(10, 20).round().to_sparse()
  142. fam.add_relation_type('Dummy Rel 1', mat)
  143. fam.add_relation_type('Dummy Rel 2', mat.clone())
  144. fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
  145. mat = torch.rand(10, 10).round().to_sparse()
  146. fam.add_relation_type('Dummy Rel 2-1', mat)
  147. fam.add_relation_type('Dummy Rel 2-2', mat.clone())
  148. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  149. m = Model(prep_d)
  150. assert len(list(m.parameters())) == 20