IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

test_trainprep.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from icosagon.trainprep import TrainValTest, \
  6. train_val_test_split_edges, \
  7. get_edges_and_degrees, \
  8. prepare_adj_mat, \
  9. prepare_relation_type, \
  10. prep_rel_one_node_type, \
  11. prep_rel_two_node_types_sym, \
  12. prep_rel_two_node_types_asym
  13. import torch
  14. import pytest
  15. import numpy as np
  16. from itertools import chain
  17. from icosagon.data import RelationType
  18. import icosagon.trainprep
  19. def test_train_val_test_split_edges_01():
  20. edges = torch.randint(0, 10, (10, 2))
  21. with pytest.raises(ValueError):
  22. _ = train_val_test_split_edges(edges, TrainValTest(.5, .5, .5))
  23. with pytest.raises(ValueError):
  24. _ = train_val_test_split_edges(edges, TrainValTest(.2, .2, .2))
  25. with pytest.raises(ValueError):
  26. _ = train_val_test_split_edges(edges, None)
  27. with pytest.raises(ValueError):
  28. _ = train_val_test_split_edges(edges, (.8, .1, .1))
  29. with pytest.raises(ValueError):
  30. _ = train_val_test_split_edges(np.random.randint(0, 10, (10, 2)), TrainValTest(.8, .1, .1))
  31. with pytest.raises(ValueError):
  32. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 3)), TrainValTest(.8, .1, .1))
  33. with pytest.raises(ValueError):
  34. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 2, 1)), TrainValTest(.8, .1, .1))
  35. with pytest.raises(ValueError):
  36. _ = train_val_test_split_edges(None, TrainValTest(.8, .2, .2))
  37. res = train_val_test_split_edges(edges, TrainValTest(.8, .1, .1))
  38. assert res.train.shape == (8, 2) and res.val.shape == (1, 2) and \
  39. res.test.shape == (1, 2)
  40. res = train_val_test_split_edges(edges, TrainValTest(.8, .0, .2))
  41. assert res.train.shape == (8, 2) and res.val.shape == (0, 2) and \
  42. res.test.shape == (2, 2)
  43. res = train_val_test_split_edges(edges, TrainValTest(.8, .2, .0))
  44. assert res.train.shape == (8, 2) and res.val.shape == (2, 2) and \
  45. res.test.shape == (0, 2)
  46. res = train_val_test_split_edges(edges, TrainValTest(.0, .5, .5))
  47. assert res.train.shape == (0, 2) and res.val.shape == (5, 2) and \
  48. res.test.shape == (5, 2)
  49. res = train_val_test_split_edges(edges, TrainValTest(.0, .0, 1.))
  50. assert res.train.shape == (0, 2) and res.val.shape == (0, 2) and \
  51. res.test.shape == (10, 2)
  52. res = train_val_test_split_edges(edges, TrainValTest(.0, 1., .0))
  53. assert res.train.shape == (0, 2) and res.val.shape == (10, 2) and \
  54. res.test.shape == (0, 2)
  55. def test_train_val_test_split_edges_02():
  56. edges = torch.randint(0, 30, (30, 2))
  57. ratios = TrainValTest(.8, .1, .1)
  58. res = train_val_test_split_edges(edges, ratios)
  59. edges = [ tuple(a) for a in edges ]
  60. res = [ tuple(a) for a in chain(res.train, res.val, res.test) ]
  61. assert all([ a in edges for a in res ])
  62. def test_get_edges_and_degrees_01():
  63. adj_mat_dense = (torch.rand((10, 10)) > .5)
  64. adj_mat_sparse = adj_mat_dense.to_sparse()
  65. edges_dense, degrees_dense = get_edges_and_degrees(adj_mat_dense)
  66. edges_sparse, degrees_sparse = get_edges_and_degrees(adj_mat_sparse)
  67. assert torch.all(degrees_dense == degrees_sparse)
  68. edges_dense = [ tuple(a) for a in edges_dense ]
  69. edges_sparse = [ tuple(a) for a in edges_dense ]
  70. assert len(edges_dense) == len(edges_sparse)
  71. assert all([ a in edges_dense for a in edges_sparse ])
  72. assert all([ a in edges_sparse for a in edges_dense ])
  73. # assert torch.all(edges_dense == edges_sparse)
  74. def test_prepare_adj_mat_01():
  75. adj_mat = (torch.rand((10, 10)) > .5)
  76. adj_mat = adj_mat.to_sparse()
  77. ratios = TrainValTest(.8, .1, .1)
  78. _ = prepare_adj_mat(adj_mat, ratios)
  79. def test_prepare_adj_mat_02():
  80. adj_mat = (torch.rand((10, 10)) > .5)
  81. adj_mat = adj_mat.to_sparse()
  82. ratios = TrainValTest(.8, .1, .1)
  83. (adj_mat_train, edges_pos, edges_neg) = prepare_adj_mat(adj_mat, ratios)
  84. assert isinstance(adj_mat_train, torch.Tensor)
  85. assert adj_mat_train.is_sparse
  86. assert adj_mat_train.shape == adj_mat.shape
  87. assert adj_mat_train.dtype == adj_mat.dtype
  88. assert isinstance(edges_pos, TrainValTest)
  89. assert isinstance(edges_neg, TrainValTest)
  90. for a in ['train', 'val', 'test']:
  91. for b in [edges_pos, edges_neg]:
  92. edges = getattr(b, a)
  93. assert isinstance(edges, torch.Tensor)
  94. assert len(edges.shape) == 2
  95. assert edges.shape[1] == 2
  96. def test_prepare_relation_type_01():
  97. adj_mat = (torch.rand((10, 10)) > .5)
  98. r = RelationType('Test', 0, 0, adj_mat, True)
  99. ratios = TrainValTest(.8, .1, .1)
  100. _ = prepare_relation_type(r, ratios, False)
  101. def test_prep_rel_one_node_type_01():
  102. adj_mat = torch.zeros(100)
  103. perm = torch.randperm(100)
  104. adj_mat[perm[:10]] = 1
  105. adj_mat = adj_mat.view(10, 10)
  106. rel = RelationType('Dummy Relation', 0, 0, adj_mat, None)
  107. ratios = TrainValTest(.8, .1, .1)
  108. prep_rel = prep_rel_one_node_type(rel, ratios)
  109. assert prep_rel.name == rel.name
  110. assert prep_rel.node_type_row == rel.node_type_row
  111. assert prep_rel.node_type_column == rel.node_type_column
  112. assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
  113. assert prep_rel.adjacency_matrix_backward is None
  114. assert len(prep_rel.edges_pos.train) == 8
  115. assert len(prep_rel.edges_pos.val) == 1
  116. assert len(prep_rel.edges_pos.test) == 1
  117. assert len(prep_rel.edges_neg.train) == 8
  118. assert len(prep_rel.edges_neg.val) == 1
  119. assert len(prep_rel.edges_neg.test) == 1
  120. assert len(prep_rel.edges_back_pos.train) == 0
  121. assert len(prep_rel.edges_back_pos.val) == 0
  122. assert len(prep_rel.edges_back_pos.test) == 0
  123. assert len(prep_rel.edges_back_neg.train) == 0
  124. assert len(prep_rel.edges_back_neg.val) == 0
  125. assert len(prep_rel.edges_back_neg.test) == 0
  126. def test_prep_rel_two_node_types_sym_01():
  127. adj_mat = torch.zeros(200)
  128. perm = torch.randperm(100)
  129. adj_mat[perm[:10]] = 1
  130. adj_mat = adj_mat.view(10, 20)
  131. rel = RelationType('Dummy Relation', 0, 1, adj_mat, None)
  132. ratios = TrainValTest(.8, .1, .1)
  133. prep_rel = prep_rel_two_node_types_sym(rel, ratios)
  134. assert prep_rel.name == rel.name
  135. assert prep_rel.node_type_row == rel.node_type_row
  136. assert prep_rel.node_type_column == rel.node_type_column
  137. assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
  138. assert prep_rel.adjacency_matrix_backward.shape == (20, 10)
  139. assert len(prep_rel.edges_pos.train) == 8
  140. assert len(prep_rel.edges_pos.val) == 1
  141. assert len(prep_rel.edges_pos.test) == 1
  142. assert len(prep_rel.edges_neg.train) == 8
  143. assert len(prep_rel.edges_neg.val) == 1
  144. assert len(prep_rel.edges_neg.test) == 1
  145. assert len(prep_rel.edges_back_pos.train) == 0
  146. assert len(prep_rel.edges_back_pos.val) == 0
  147. assert len(prep_rel.edges_back_pos.test) == 0
  148. assert len(prep_rel.edges_back_neg.train) == 0
  149. assert len(prep_rel.edges_back_neg.val) == 0
  150. assert len(prep_rel.edges_back_neg.test) == 0
  151. def test_prep_rel_two_node_types_asym_01():
  152. adj_mat = torch.zeros(200)
  153. perm = torch.randperm(100)
  154. adj_mat[perm[:10]] = 1
  155. adj_mat = adj_mat.view(10, 20)
  156. adj_mat_back = torch.zeros(200)
  157. perm = torch.randperm(100)
  158. adj_mat_back[perm[:10]] = 1
  159. adj_mat_back = adj_mat_back.view(20, 10)
  160. rel = RelationType('Dummy Relation', 0, 1, adj_mat, adj_mat_back)
  161. ratios = TrainValTest(.8, .1, .1)
  162. prep_rel = prep_rel_two_node_types_asym(rel, ratios)
  163. assert prep_rel.name == rel.name
  164. assert prep_rel.node_type_row == rel.node_type_row
  165. assert prep_rel.node_type_column == rel.node_type_column
  166. assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
  167. assert prep_rel.adjacency_matrix_backward.shape == rel.adjacency_matrix_backward.shape
  168. assert len(prep_rel.edges_pos.train) == 8
  169. assert len(prep_rel.edges_pos.val) == 1
  170. assert len(prep_rel.edges_pos.test) == 1
  171. assert len(prep_rel.edges_neg.train) == 8
  172. assert len(prep_rel.edges_neg.val) == 1
  173. assert len(prep_rel.edges_neg.test) == 1
  174. assert len(prep_rel.edges_back_pos.train) == 8
  175. assert len(prep_rel.edges_back_pos.val) == 1
  176. assert len(prep_rel.edges_back_pos.test) == 1
  177. assert len(prep_rel.edges_back_neg.train) == 8
  178. assert len(prep_rel.edges_back_neg.val) == 1
  179. assert len(prep_rel.edges_back_neg.test) == 1
  180. def test_prepare_relation_type_02():
  181. with pytest.raises(ValueError):
  182. prepare_relation_type(None, None, True)
  183. adj_mat = torch.rand(10, 10).round()
  184. rel = RelationType('Dummy Relation', 0, 0, adj_mat, None)
  185. with pytest.raises(ValueError):
  186. prepare_relation_type(rel, None, True)
  187. ratios = TrainValTest(.8, .1, .1)
  188. with pytest.raises(ValueError):
  189. prepare_relation_type(None, ratios, True)
  190. _ = prepare_relation_type(rel, ratios, True)
  191. def test_prepare_relation_type_03(monkeypatch):
  192. a = 0
  193. b = 0
  194. c = 0
  195. def fake_prep_rel_one_node_type(*args, **kwargs):
  196. nonlocal a
  197. a += 1
  198. def fake_prep_rel_two_node_types_sym(*args, **kwargs):
  199. nonlocal b
  200. b += 1
  201. def fake_prep_rel_two_node_types_asym(*args, **kwargs):
  202. nonlocal c
  203. c += 1
  204. monkeypatch.setattr(icosagon.trainprep, 'prep_rel_one_node_type',
  205. fake_prep_rel_one_node_type)
  206. monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_sym',
  207. fake_prep_rel_two_node_types_sym)
  208. monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_asym',
  209. fake_prep_rel_two_node_types_asym)
  210. ratios = TrainValTest(.8, .1, .1)
  211. rel = RelationType('Dummy Relation', 0, 0, None, None)
  212. prepare_relation_type(rel, ratios, False)
  213. assert a == 1
  214. rel = RelationType('Dummy Relation', 0, 0, None, None)
  215. prepare_relation_type(rel, ratios, True)
  216. assert a == 2
  217. rel = RelationType('Dummy Relation', 0, 1, None, None)
  218. prepare_relation_type(rel, ratios, True)
  219. assert b == 1
  220. rel = RelationType('Dummy Relation', 0, 1, None, None)
  221. prepare_relation_type(rel, ratios, False)
  222. assert c == 1
  223. assert a == 2 and b == 1 and c == 1
  224. # def prepare_relation(r, ratios):
  225. # adj_mat = r.adjacency_matrix
  226. # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)
  227. #
  228. # if r.node_type_row == r.node_type_column:
  229. # adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  230. # else:
  231. # adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  232. #
  233. # return PreparedRelation(r.name, r.node_type_row, r.node_type_column,
  234. # adj_mat_train, edges_pos, edges_neg)