IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

235 lines
9.3KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from icosagon.input import OneHotInputLayer
  6. from icosagon.convolve import DropoutGraphConvActivation
  7. from icosagon.convlayer import DecagonLayer
  8. from icosagon.declayer import DecodeLayer, \
  9. Predictions, \
  10. RelationFamilyPredictions, \
  11. RelationPredictions
  12. from icosagon.decode import DEDICOMDecoder, \
  13. InnerProductDecoder
  14. from icosagon.data import Data
  15. from icosagon.trainprep import prepare_training, \
  16. TrainValTest
  17. import torch
  18. def test_decode_layer_01():
  19. d = Data()
  20. d.add_node_type('Dummy', 100)
  21. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  22. fam.add_relation_type('Dummy Relation 1',
  23. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  24. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  25. in_layer = OneHotInputLayer(d)
  26. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  27. seq = torch.nn.Sequential(in_layer, d_layer)
  28. last_layer_repr = seq(None)
  29. dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
  30. activation=lambda x: x)
  31. pred = dec(last_layer_repr)
  32. assert isinstance(pred, Predictions)
  33. assert isinstance(pred.relation_families, list)
  34. assert len(pred.relation_families) == 1
  35. assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
  36. assert isinstance(pred.relation_families[0].relation_types, list)
  37. assert len(pred.relation_families[0].relation_types) == 1
  38. assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions)
  39. tmp = pred.relation_families[0].relation_types[0]
  40. assert isinstance(tmp.edges_pos, TrainValTest)
  41. assert isinstance(tmp.edges_neg, TrainValTest)
  42. assert isinstance(tmp.edges_back_pos, TrainValTest)
  43. assert isinstance(tmp.edges_back_neg, TrainValTest)
  44. def test_decode_layer_02():
  45. d = Data()
  46. d.add_node_type('Dummy', 100)
  47. fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
  48. fam.add_relation_type('Dummy Relation 1',
  49. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  50. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  51. in_layer = OneHotInputLayer(d)
  52. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  53. dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
  54. keep_prob=1., activation=lambda x: x)
  55. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  56. pred = seq(None)
  57. assert isinstance(pred, Predictions)
  58. assert len(pred.relation_families) == 1
  59. assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
  60. assert isinstance(pred.relation_families[0].relation_types, list)
  61. assert len(pred.relation_families[0].relation_types) == 1
  62. def test_decode_layer_03():
  63. d = Data()
  64. d.add_node_type('Dummy 1', 100)
  65. d.add_node_type('Dummy 2', 100)
  66. fam = d.add_relation_family('Dummy 1-Dummy 2', 0, 1, True)
  67. fam.add_relation_type('Dummy Relation 1',
  68. torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
  69. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  70. in_layer = OneHotInputLayer(d)
  71. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  72. dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
  73. keep_prob=1., activation=lambda x: x)
  74. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  75. pred = seq(None)
  76. assert isinstance(pred, Predictions)
  77. assert len(pred.relation_families) == 1
  78. assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
  79. assert isinstance(pred.relation_families[0].relation_types, list)
  80. assert len(pred.relation_families[0].relation_types) == 1
  81. assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions)
  82. def test_decode_layer_04():
  83. d = Data()
  84. d.add_node_type('Dummy', 100)
  85. assert len(d.relation_families) == 0
  86. prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
  87. in_layer = OneHotInputLayer(d)
  88. d_layer = DecagonLayer(in_layer.output_dim, 32, d)
  89. dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
  90. keep_prob=1., activation=lambda x: x)
  91. seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
  92. pred = seq(None)
  93. assert isinstance(pred, Predictions)
  94. assert len(pred.relation_families) == 0
  95. def test_decode_layer_05():
  96. d = Data()
  97. d.add_node_type('Dummy', 10)
  98. mat = torch.rand((10, 10))
  99. mat = (mat + mat.transpose(0, 1)) / 2
  100. mat = mat.round()
  101. fam = d.add_relation_family('Dummy-Dummy', 0, 0, True,
  102. decoder_class=InnerProductDecoder)
  103. fam.add_relation_type('Dummy Rel', mat.to_sparse())
  104. prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
  105. in_layer = OneHotInputLayer(d)
  106. conv_layer = DecagonLayer(in_layer.output_dim, 32, prep_d,
  107. rel_activation=lambda x: x, layer_activation=lambda x: x)
  108. dec_layer = DecodeLayer(conv_layer.output_dim, prep_d,
  109. keep_prob=1., activation=lambda x: x)
  110. seq = torch.nn.Sequential(in_layer, conv_layer, dec_layer)
  111. pred = seq(None)
  112. rel_pred = pred.relation_families[0].relation_types[0]
  113. for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']:
  114. edge_pred = getattr(rel_pred, edge_type)
  115. assert isinstance(edge_pred, TrainValTest)
  116. for part_type in ['train', 'val', 'test']:
  117. part_pred = getattr(edge_pred, part_type)
  118. assert isinstance(part_pred, torch.Tensor)
  119. assert len(part_pred.shape) == 1
  120. print(edge_type, part_type, part_pred.shape)
  121. if (edge_type, part_type) not in [('edges_pos', 'train'), ('edges_neg', 'train')]:
  122. assert part_pred.shape[0] == 0
  123. else:
  124. assert part_pred.shape[0] > 0
  125. prep_rel = prep_d.relation_families[0].relation_types[0]
  126. assert len(rel_pred.edges_pos.train) == len(prep_rel.edges_pos.train)
  127. assert len(rel_pred.edges_neg.train) == len(prep_rel.edges_neg.train)
  128. assert len(prep_rel.edges_pos.train) == torch.sum(mat)
  129. # print('Predictions for positive edges:')
  130. # print(rel_pred.edges_pos.train)
  131. # print('Predictions for negative edges:')
  132. # print(rel_pred.edges_neg.train)
  133. repr_in = in_layer(None)
  134. assert isinstance(repr_in, list)
  135. assert len(repr_in) == 1
  136. assert isinstance(repr_in[0], torch.Tensor)
  137. assert torch.all(repr_in[0].to_dense() == torch.eye(10))
  138. assert len(conv_layer.next_layer_repr[0]) == 1
  139. assert len(conv_layer.next_layer_repr[0][0].convolutions) == 1
  140. assert conv_layer.rel_activation(0) == 0
  141. assert conv_layer.rel_activation(1) == 1
  142. assert conv_layer.rel_activation(-1) == -1
  143. assert conv_layer.layer_activation(0) == 0
  144. assert conv_layer.layer_activation(1) == 1
  145. assert conv_layer.layer_activation(-1) == -1
  146. graph_conv = conv_layer.next_layer_repr[0][0].convolutions[0]
  147. assert isinstance(graph_conv, DropoutGraphConvActivation)
  148. assert graph_conv.activation(0) == 0
  149. assert graph_conv.activation(1) == 1
  150. assert graph_conv.activation(-1) == -1
  151. weight = graph_conv.graph_conv.weight
  152. adj_mat = prep_d.relation_families[0].relation_types[0].adjacency_matrix
  153. repr_conv = torch.sparse.mm(repr_in[0], weight)
  154. repr_conv = torch.mm(adj_mat, repr_conv)
  155. repr_conv = torch.nn.functional.normalize(repr_conv, p=2, dim=1)
  156. repr_conv_expect = conv_layer(repr_in)[0]
  157. print('repr_conv:\n', repr_conv)
  158. # print(repr_conv_expect)
  159. assert torch.all(repr_conv == repr_conv_expect)
  160. assert repr_conv.shape[1] == 32
  161. dec = InnerProductDecoder(32, 1, keep_prob=1., activation=lambda x: x)
  162. x, y = torch.meshgrid(torch.arange(0, 10), torch.arange(0, 10))
  163. x = x.flatten()
  164. y = y.flatten()
  165. repr_dec_expect = dec(repr_conv[x], repr_conv[y], 0)
  166. repr_dec_expect = repr_dec_expect.view(10, 10)
  167. repr_dec = torch.mm(repr_conv, torch.transpose(repr_conv, 0, 1))
  168. # repr_dec = torch.flatten(repr_dec)
  169. # repr_dec -= torch.eye(10)
  170. assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
  171. repr_dec_expect = torch.zeros((10, 10))
  172. x = prep_d.relation_families[0].relation_types[0].edges_pos.train
  173. repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train
  174. x = prep_d.relation_families[0].relation_types[0].edges_neg.train
  175. repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train
  176. print(repr_dec)
  177. print(repr_dec_expect)
  178. repr_dec = torch.zeros((10, 10))
  179. x = prep_d.relation_families[0].relation_types[0].edges_pos.train
  180. repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0)
  181. x = prep_d.relation_families[0].relation_types[0].edges_neg.train
  182. repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0)
  183. assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
  184. #print(prep_rel.edges_pos.train)
  185. #print(prep_rel.edges_neg.train)
  186. # assert isinstance(edge_pred.train)
  187. # assert isinstance(rel_pred.edges_pos, TrainValTest)
  188. # assert isinstance(rel_pred.edges_neg, TrainValTest)
  189. # assert isinstance(rel_pred.edges_back_pos, TrainValTest)
  190. # assert isinstance(rel_pred.edges_back_neg, TrainValTest)