IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or generating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_model.py 6.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from icosagon.data import Data, \
  2. _equal
  3. from icosagon.model import Model
  4. from icosagon.trainprep import PreparedData, \
  5. PreparedRelationFamily, \
  6. PreparedRelationType, \
  7. TrainValTest, \
  8. norm_adj_mat_one_node_type
  9. import torch
  10. from icosagon.input import OneHotInputLayer
  11. from icosagon.convlayer import DecagonLayer
  12. from icosagon.declayer import DecodeLayer
  13. import pytest
  14. def _is_identity_function(f):
  15. for x in range(-100, 101):
  16. if f(x) != x:
  17. return False
  18. return True
  def test_model_01():
      """Check the attributes of a Model built from Data with no overrides."""
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      # Rounding uniform samples from [0, 1) yields a random 0/1 matrix.
      adjacency = torch.rand(10, 10).round()
      family.add_relation_type('Dummy Rel', adjacency)

      model = Model(data)

      # Constructor arguments are stored on the model; these are the values
      # expected when none of them is passed explicitly.
      assert model.data == data
      assert model.layer_dimensions == [32, 64]
      assert (model.ratios.train, model.ratios.val, model.ratios.test) == \
          (.8, .1, .1)
      assert model.keep_prob == 1.
      assert _is_identity_function(model.rel_activation)
      assert model.layer_activation == torch.nn.functional.relu
      assert _is_identity_function(model.dec_activation)
      assert model.lr == 0.001
      assert model.loss == torch.nn.functional.binary_cross_entropy_with_logits
      assert model.batch_size == 100

      # Objects the model builds for itself.
      assert isinstance(model.prep_d, PreparedData)
      assert isinstance(model.seq, torch.nn.Sequential)
      assert isinstance(model.opt, torch.optim.Optimizer)
  def test_model_02():
      """Check the prepared data and the layer stack of a one-relation Model."""
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      adjacency = torch.rand(10, 10).round().to_sparse()
      family.add_relation_type('Dummy Rel', adjacency)

      # All edges are assigned to the training split.
      model = Model(data, ratios=TrainValTest(1., 0., 0.))

      assert isinstance(model.prep_d, PreparedData)
      assert isinstance(model.prep_d.relation_families, list)
      assert len(model.prep_d.relation_families) == 1

      prep_family = model.prep_d.relation_families[0]
      assert isinstance(prep_family, PreparedRelationFamily)
      assert len(prep_family.relation_types) == 1

      prep_rel = prep_family.relation_types[0]
      assert isinstance(prep_rel, PreparedRelationType)
      assert prep_rel.adjacency_matrix_backward is None
      expected_adjacency = norm_adj_mat_one_node_type(adjacency)
      assert torch.all(_equal(prep_rel.adjacency_matrix, expected_adjacency))

      # Expected layer types, in sequence order.
      expected_layers = (OneHotInputLayer, DecagonLayer, DecagonLayer,
                         DecodeLayer)
      for index, layer_type in enumerate(expected_layers):
          assert isinstance(model.seq[index], layer_type)
      assert len(model.seq) == 4
  def test_model_03():
      """Check the optimizer state dict and per-layer parameter counts
      for a Model with one relation family holding one relation type."""
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      adjacency = torch.rand(10, 10).round().to_sparse()
      family.add_relation_type('Dummy Rel', adjacency)

      model = Model(data, ratios=TrainValTest(1., 0., 0.))

      optimizer_state = model.opt.state_dict()
      assert isinstance(optimizer_state, dict)

      # Expected number of parameter tensors in model.seq[0] .. model.seq[3].
      for index, expected_count in enumerate((1, 1, 1, 2)):
          layer_params = list(model.seq[index].parameters())
          assert len(layer_params) == expected_count
  def test_model_04():
      """Check per-layer parameter counts for one relation family that
      holds two relation types."""
      data = Data()
      data.add_node_type('Dummy', 10)
      family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      adjacency = torch.rand(10, 10).round().to_sparse()
      # Both relation types carry the same edges; the second gets its own copy.
      for rel_name, rel_matrix in (('Dummy Rel 1', adjacency),
                                   ('Dummy Rel 2', adjacency.clone())):
          family.add_relation_type(rel_name, rel_matrix)

      model = Model(data, ratios=TrainValTest(1., 0., 0.))

      # Expected number of parameter tensors in model.seq[0] .. model.seq[3].
      for index, expected_count in enumerate((1, 2, 2, 3)):
          assert len(list(model.seq[index].parameters())) == expected_count
  def test_model_05():
      """Check per-layer parameter counts for two relation families on a
      single node type, each holding two relation types."""
      data = Data()
      data.add_node_type('Dummy', 10)

      first_family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
      first_adjacency = torch.rand(10, 10).round().to_sparse()
      first_family.add_relation_type('Dummy Rel 1', first_adjacency)
      first_family.add_relation_type('Dummy Rel 2', first_adjacency.clone())

      second_family = data.add_relation_family('Dummy-Dummy 2', 0, 0, False)
      second_adjacency = torch.rand(10, 10).round().to_sparse()
      second_family.add_relation_type('Dummy Rel 2-1', second_adjacency)
      second_family.add_relation_type('Dummy Rel 2-2', second_adjacency.clone())

      model = Model(data, ratios=TrainValTest(1., 0., 0.))

      # Expected number of parameter tensors in model.seq[0] .. model.seq[3].
      for index, expected_count in enumerate((1, 4, 4, 6)):
          assert len(list(model.seq[index].parameters())) == expected_count
  def test_model_05_two_node_types():
      """Check per-layer parameter counts with two node types.

      The data holds a 'Dummy-Foobar' family between node types 0 and 1
      (created with its last argument set to True) and a 'Dummy-Dummy 2'
      family on node type 0 alone, each with two relation types.

      This test used to be named ``test_model_05``, the same as another test
      in this module. The later definition rebinds the module-level name, so
      pytest collected only one of the two. The unique name lets both run.
      """
      d = Data()
      d.add_node_type('Dummy', 10)
      d.add_node_type('Foobar', 20)

      fam = d.add_relation_family('Dummy-Foobar', 0, 1, True)
      # 10 x 20: rows follow the 'Dummy' count, columns the 'Foobar' count.
      mat = torch.rand(10, 20).round().to_sparse()
      fam.add_relation_type('Dummy Rel 1', mat)
      fam.add_relation_type('Dummy Rel 2', mat.clone())

      fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
      mat = torch.rand(10, 10).round().to_sparse()
      fam.add_relation_type('Dummy Rel 2-1', mat)
      fam.add_relation_type('Dummy Rel 2-2', mat.clone())

      m = Model(d, ratios=TrainValTest(1., 0., 0.))

      # Expected number of parameter tensors in each of the four layers.
      assert len(list(m.seq[0].parameters())) == 2
      assert len(list(m.seq[1].parameters())) == 6
      assert len(list(m.seq[2].parameters())) == 6
      assert len(list(m.seq[3].parameters())) == 6
  def test_model_06():
      """Check that Model rejects invalid constructor arguments.

      Each call passes one bad argument and must raise the listed exception.
      The constructed object is never used, so the calls are not bound to a
      local name.
      """
      d = Data()
      d.add_node_type('Dummy', 10)
      fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
      fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())

      # An int in place of the Data object.
      with pytest.raises(TypeError):
          Model(1)

      with pytest.raises(TypeError):
          Model(d, layer_dimensions=1)

      with pytest.raises(TypeError):
          Model(d, ratios=1)

      # The string arguments below for keep_prob, lr and batch_size are
      # expected to raise ValueError rather than TypeError.
      with pytest.raises(ValueError):
          Model(d, keep_prob='x')

      with pytest.raises(TypeError):
          Model(d, rel_activation='x')

      with pytest.raises(TypeError):
          Model(d, layer_activation='x')

      with pytest.raises(TypeError):
          Model(d, dec_activation='x')

      with pytest.raises(ValueError):
          Model(d, lr='x')

      with pytest.raises(TypeError):
          Model(d, loss=1)

      with pytest.raises(ValueError):
          Model(d, batch_size='x')
  139. with pytest.raises(ValueError):
  140. m = Model(d, lr='x')
  141. with pytest.raises(TypeError):
  142. m = Model(d, loss=1)
  143. with pytest.raises(ValueError):
  144. m = Model(d, batch_size='x')