IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

test_trainprep.py 7.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from icosagon.trainprep import TrainValTest, \
  6. train_val_test_split_edges, \
  7. get_edges_and_degrees, \
  8. prepare_adj_mat, \
  9. prepare_relation_type, \
  10. prep_rel_one_node_type, \
  11. prep_rel_two_node_types_sym
  12. import torch
  13. import pytest
  14. import numpy as np
  15. from itertools import chain
  16. from icosagon.data import RelationType
  17. def test_train_val_test_split_edges_01():
  18. edges = torch.randint(0, 10, (10, 2))
  19. with pytest.raises(ValueError):
  20. _ = train_val_test_split_edges(edges, TrainValTest(.5, .5, .5))
  21. with pytest.raises(ValueError):
  22. _ = train_val_test_split_edges(edges, TrainValTest(.2, .2, .2))
  23. with pytest.raises(ValueError):
  24. _ = train_val_test_split_edges(edges, None)
  25. with pytest.raises(ValueError):
  26. _ = train_val_test_split_edges(edges, (.8, .1, .1))
  27. with pytest.raises(ValueError):
  28. _ = train_val_test_split_edges(np.random.randint(0, 10, (10, 2)), TrainValTest(.8, .1, .1))
  29. with pytest.raises(ValueError):
  30. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 3)), TrainValTest(.8, .1, .1))
  31. with pytest.raises(ValueError):
  32. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 2, 1)), TrainValTest(.8, .1, .1))
  33. with pytest.raises(ValueError):
  34. _ = train_val_test_split_edges(None, TrainValTest(.8, .2, .2))
  35. res = train_val_test_split_edges(edges, TrainValTest(.8, .1, .1))
  36. assert res.train.shape == (8, 2) and res.val.shape == (1, 2) and \
  37. res.test.shape == (1, 2)
  38. res = train_val_test_split_edges(edges, TrainValTest(.8, .0, .2))
  39. assert res.train.shape == (8, 2) and res.val.shape == (0, 2) and \
  40. res.test.shape == (2, 2)
  41. res = train_val_test_split_edges(edges, TrainValTest(.8, .2, .0))
  42. assert res.train.shape == (8, 2) and res.val.shape == (2, 2) and \
  43. res.test.shape == (0, 2)
  44. res = train_val_test_split_edges(edges, TrainValTest(.0, .5, .5))
  45. assert res.train.shape == (0, 2) and res.val.shape == (5, 2) and \
  46. res.test.shape == (5, 2)
  47. res = train_val_test_split_edges(edges, TrainValTest(.0, .0, 1.))
  48. assert res.train.shape == (0, 2) and res.val.shape == (0, 2) and \
  49. res.test.shape == (10, 2)
  50. res = train_val_test_split_edges(edges, TrainValTest(.0, 1., .0))
  51. assert res.train.shape == (0, 2) and res.val.shape == (10, 2) and \
  52. res.test.shape == (0, 2)
  53. def test_train_val_test_split_edges_02():
  54. edges = torch.randint(0, 30, (30, 2))
  55. ratios = TrainValTest(.8, .1, .1)
  56. res = train_val_test_split_edges(edges, ratios)
  57. edges = [ tuple(a) for a in edges ]
  58. res = [ tuple(a) for a in chain(res.train, res.val, res.test) ]
  59. assert all([ a in edges for a in res ])
  60. def test_get_edges_and_degrees_01():
  61. adj_mat_dense = (torch.rand((10, 10)) > .5)
  62. adj_mat_sparse = adj_mat_dense.to_sparse()
  63. edges_dense, degrees_dense = get_edges_and_degrees(adj_mat_dense)
  64. edges_sparse, degrees_sparse = get_edges_and_degrees(adj_mat_sparse)
  65. assert torch.all(degrees_dense == degrees_sparse)
  66. edges_dense = [ tuple(a) for a in edges_dense ]
  67. edges_sparse = [ tuple(a) for a in edges_dense ]
  68. assert len(edges_dense) == len(edges_sparse)
  69. assert all([ a in edges_dense for a in edges_sparse ])
  70. assert all([ a in edges_sparse for a in edges_dense ])
  71. # assert torch.all(edges_dense == edges_sparse)
  72. def test_prepare_adj_mat_01():
  73. adj_mat = (torch.rand((10, 10)) > .5)
  74. adj_mat = adj_mat.to_sparse()
  75. ratios = TrainValTest(.8, .1, .1)
  76. _ = prepare_adj_mat(adj_mat, ratios)
  77. def test_prepare_adj_mat_02():
  78. adj_mat = (torch.rand((10, 10)) > .5)
  79. adj_mat = adj_mat.to_sparse()
  80. ratios = TrainValTest(.8, .1, .1)
  81. (adj_mat_train, edges_pos, edges_neg) = prepare_adj_mat(adj_mat, ratios)
  82. assert isinstance(adj_mat_train, torch.Tensor)
  83. assert adj_mat_train.is_sparse
  84. assert adj_mat_train.shape == adj_mat.shape
  85. assert adj_mat_train.dtype == adj_mat.dtype
  86. assert isinstance(edges_pos, TrainValTest)
  87. assert isinstance(edges_neg, TrainValTest)
  88. for a in ['train', 'val', 'test']:
  89. for b in [edges_pos, edges_neg]:
  90. edges = getattr(b, a)
  91. assert isinstance(edges, torch.Tensor)
  92. assert len(edges.shape) == 2
  93. assert edges.shape[1] == 2
  94. def test_prepare_relation_type_01():
  95. adj_mat = (torch.rand((10, 10)) > .5)
  96. r = RelationType('Test', 0, 0, adj_mat, True)
  97. ratios = TrainValTest(.8, .1, .1)
  98. _ = prepare_relation_type(r, ratios, False)
  99. def test_prep_rel_one_node_type_01():
  100. adj_mat = torch.zeros(100)
  101. perm = torch.randperm(100)
  102. adj_mat[perm[:10]] = 1
  103. adj_mat = adj_mat.view(10, 10)
  104. rel = RelationType('Dummy Relation', 0, 0, adj_mat, None)
  105. ratios = TrainValTest(.8, .1, .1)
  106. prep_rel = prep_rel_one_node_type(rel, ratios)
  107. assert prep_rel.name == rel.name
  108. assert prep_rel.node_type_row == rel.node_type_row
  109. assert prep_rel.node_type_column == rel.node_type_column
  110. assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
  111. assert prep_rel.adjacency_matrix_backward is None
  112. assert len(prep_rel.edges_pos.train) == 8
  113. assert len(prep_rel.edges_pos.val) == 1
  114. assert len(prep_rel.edges_pos.test) == 1
  115. assert len(prep_rel.edges_neg.train) == 8
  116. assert len(prep_rel.edges_neg.val) == 1
  117. assert len(prep_rel.edges_neg.test) == 1
  118. assert len(prep_rel.edges_back_pos.train) == 0
  119. assert len(prep_rel.edges_back_pos.val) == 0
  120. assert len(prep_rel.edges_back_pos.test) == 0
  121. assert len(prep_rel.edges_back_neg.train) == 0
  122. assert len(prep_rel.edges_back_neg.val) == 0
  123. assert len(prep_rel.edges_back_neg.test) == 0
  124. def test_prep_rel_two_node_types_sym_01():
  125. adj_mat = torch.zeros(200)
  126. perm = torch.randperm(100)
  127. adj_mat[perm[:10]] = 1
  128. adj_mat = adj_mat.view(10, 20)
  129. rel = RelationType('Dummy Relation', 0, 1, adj_mat, None)
  130. ratios = TrainValTest(.8, .1, .1)
  131. prep_rel = prep_rel_two_node_types_sym(rel, ratios)
  132. assert prep_rel.name == rel.name
  133. assert prep_rel.node_type_row == rel.node_type_row
  134. assert prep_rel.node_type_column == rel.node_type_column
  135. assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
  136. assert prep_rel.adjacency_matrix_backward.shape == (20, 10)
  137. assert len(prep_rel.edges_pos.train) == 8
  138. assert len(prep_rel.edges_pos.val) == 1
  139. assert len(prep_rel.edges_pos.test) == 1
  140. assert len(prep_rel.edges_neg.train) == 8
  141. assert len(prep_rel.edges_neg.val) == 1
  142. assert len(prep_rel.edges_neg.test) == 1
  143. assert len(prep_rel.edges_back_pos.train) == 0
  144. assert len(prep_rel.edges_back_pos.val) == 0
  145. assert len(prep_rel.edges_back_pos.test) == 0
  146. assert len(prep_rel.edges_back_neg.train) == 0
  147. assert len(prep_rel.edges_back_neg.val) == 0
  148. assert len(prep_rel.edges_back_neg.test) == 0
  149. # def prepare_relation(r, ratios):
  150. # adj_mat = r.adjacency_matrix
  151. # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)
  152. #
  153. # if r.node_type_row == r.node_type_column:
  154. # adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  155. # else:
  156. # adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  157. #
  158. # return PreparedRelation(r.name, r.node_type_row, r.node_type_column,
  159. # adj_mat_train, edges_pos, edges_neg)