IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

217 lines
8.8KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from icosagon.trainprep import TrainValTest, \
  6. train_val_test_split_edges, \
  7. get_edges_and_degrees, \
  8. prepare_adj_mat, \
  9. prepare_relation_type, \
  10. prep_rel_one_node_type, \
  11. prep_rel_two_node_types_sym, \
  12. prep_rel_two_node_types_asym
  13. import torch
  14. import pytest
  15. import numpy as np
  16. from itertools import chain
  17. from icosagon.data import RelationType
  18. def test_train_val_test_split_edges_01():
  19. edges = torch.randint(0, 10, (10, 2))
  20. with pytest.raises(ValueError):
  21. _ = train_val_test_split_edges(edges, TrainValTest(.5, .5, .5))
  22. with pytest.raises(ValueError):
  23. _ = train_val_test_split_edges(edges, TrainValTest(.2, .2, .2))
  24. with pytest.raises(ValueError):
  25. _ = train_val_test_split_edges(edges, None)
  26. with pytest.raises(ValueError):
  27. _ = train_val_test_split_edges(edges, (.8, .1, .1))
  28. with pytest.raises(ValueError):
  29. _ = train_val_test_split_edges(np.random.randint(0, 10, (10, 2)), TrainValTest(.8, .1, .1))
  30. with pytest.raises(ValueError):
  31. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 3)), TrainValTest(.8, .1, .1))
  32. with pytest.raises(ValueError):
  33. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 2, 1)), TrainValTest(.8, .1, .1))
  34. with pytest.raises(ValueError):
  35. _ = train_val_test_split_edges(None, TrainValTest(.8, .2, .2))
  36. res = train_val_test_split_edges(edges, TrainValTest(.8, .1, .1))
  37. assert res.train.shape == (8, 2) and res.val.shape == (1, 2) and \
  38. res.test.shape == (1, 2)
  39. res = train_val_test_split_edges(edges, TrainValTest(.8, .0, .2))
  40. assert res.train.shape == (8, 2) and res.val.shape == (0, 2) and \
  41. res.test.shape == (2, 2)
  42. res = train_val_test_split_edges(edges, TrainValTest(.8, .2, .0))
  43. assert res.train.shape == (8, 2) and res.val.shape == (2, 2) and \
  44. res.test.shape == (0, 2)
  45. res = train_val_test_split_edges(edges, TrainValTest(.0, .5, .5))
  46. assert res.train.shape == (0, 2) and res.val.shape == (5, 2) and \
  47. res.test.shape == (5, 2)
  48. res = train_val_test_split_edges(edges, TrainValTest(.0, .0, 1.))
  49. assert res.train.shape == (0, 2) and res.val.shape == (0, 2) and \
  50. res.test.shape == (10, 2)
  51. res = train_val_test_split_edges(edges, TrainValTest(.0, 1., .0))
  52. assert res.train.shape == (0, 2) and res.val.shape == (10, 2) and \
  53. res.test.shape == (0, 2)
  54. def test_train_val_test_split_edges_02():
  55. edges = torch.randint(0, 30, (30, 2))
  56. ratios = TrainValTest(.8, .1, .1)
  57. res = train_val_test_split_edges(edges, ratios)
  58. edges = [ tuple(a) for a in edges ]
  59. res = [ tuple(a) for a in chain(res.train, res.val, res.test) ]
  60. assert all([ a in edges for a in res ])
  61. def test_get_edges_and_degrees_01():
  62. adj_mat_dense = (torch.rand((10, 10)) > .5)
  63. adj_mat_sparse = adj_mat_dense.to_sparse()
  64. edges_dense, degrees_dense = get_edges_and_degrees(adj_mat_dense)
  65. edges_sparse, degrees_sparse = get_edges_and_degrees(adj_mat_sparse)
  66. assert torch.all(degrees_dense == degrees_sparse)
  67. edges_dense = [ tuple(a) for a in edges_dense ]
  68. edges_sparse = [ tuple(a) for a in edges_dense ]
  69. assert len(edges_dense) == len(edges_sparse)
  70. assert all([ a in edges_dense for a in edges_sparse ])
  71. assert all([ a in edges_sparse for a in edges_dense ])
  72. # assert torch.all(edges_dense == edges_sparse)
  73. def test_prepare_adj_mat_01():
  74. adj_mat = (torch.rand((10, 10)) > .5)
  75. adj_mat = adj_mat.to_sparse()
  76. ratios = TrainValTest(.8, .1, .1)
  77. _ = prepare_adj_mat(adj_mat, ratios)
  78. def test_prepare_adj_mat_02():
  79. adj_mat = (torch.rand((10, 10)) > .5)
  80. adj_mat = adj_mat.to_sparse()
  81. ratios = TrainValTest(.8, .1, .1)
  82. (adj_mat_train, edges_pos, edges_neg) = prepare_adj_mat(adj_mat, ratios)
  83. assert isinstance(adj_mat_train, torch.Tensor)
  84. assert adj_mat_train.is_sparse
  85. assert adj_mat_train.shape == adj_mat.shape
  86. assert adj_mat_train.dtype == adj_mat.dtype
  87. assert isinstance(edges_pos, TrainValTest)
  88. assert isinstance(edges_neg, TrainValTest)
  89. for a in ['train', 'val', 'test']:
  90. for b in [edges_pos, edges_neg]:
  91. edges = getattr(b, a)
  92. assert isinstance(edges, torch.Tensor)
  93. assert len(edges.shape) == 2
  94. assert edges.shape[1] == 2
  95. def test_prepare_relation_type_01():
  96. adj_mat = (torch.rand((10, 10)) > .5)
  97. r = RelationType('Test', 0, 0, adj_mat, True)
  98. ratios = TrainValTest(.8, .1, .1)
  99. _ = prepare_relation_type(r, ratios, False)
  100. def test_prep_rel_one_node_type_01():
  101. adj_mat = torch.zeros(100)
  102. perm = torch.randperm(100)
  103. adj_mat[perm[:10]] = 1
  104. adj_mat = adj_mat.view(10, 10)
  105. rel = RelationType('Dummy Relation', 0, 0, adj_mat, None)
  106. ratios = TrainValTest(.8, .1, .1)
  107. prep_rel = prep_rel_one_node_type(rel, ratios)
  108. assert prep_rel.name == rel.name
  109. assert prep_rel.node_type_row == rel.node_type_row
  110. assert prep_rel.node_type_column == rel.node_type_column
  111. assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
  112. assert prep_rel.adjacency_matrix_backward is None
  113. assert len(prep_rel.edges_pos.train) == 8
  114. assert len(prep_rel.edges_pos.val) == 1
  115. assert len(prep_rel.edges_pos.test) == 1
  116. assert len(prep_rel.edges_neg.train) == 8
  117. assert len(prep_rel.edges_neg.val) == 1
  118. assert len(prep_rel.edges_neg.test) == 1
  119. assert len(prep_rel.edges_back_pos.train) == 0
  120. assert len(prep_rel.edges_back_pos.val) == 0
  121. assert len(prep_rel.edges_back_pos.test) == 0
  122. assert len(prep_rel.edges_back_neg.train) == 0
  123. assert len(prep_rel.edges_back_neg.val) == 0
  124. assert len(prep_rel.edges_back_neg.test) == 0
  125. def test_prep_rel_two_node_types_sym_01():
  126. adj_mat = torch.zeros(200)
  127. perm = torch.randperm(100)
  128. adj_mat[perm[:10]] = 1
  129. adj_mat = adj_mat.view(10, 20)
  130. rel = RelationType('Dummy Relation', 0, 1, adj_mat, None)
  131. ratios = TrainValTest(.8, .1, .1)
  132. prep_rel = prep_rel_two_node_types_sym(rel, ratios)
  133. assert prep_rel.name == rel.name
  134. assert prep_rel.node_type_row == rel.node_type_row
  135. assert prep_rel.node_type_column == rel.node_type_column
  136. assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
  137. assert prep_rel.adjacency_matrix_backward.shape == (20, 10)
  138. assert len(prep_rel.edges_pos.train) == 8
  139. assert len(prep_rel.edges_pos.val) == 1
  140. assert len(prep_rel.edges_pos.test) == 1
  141. assert len(prep_rel.edges_neg.train) == 8
  142. assert len(prep_rel.edges_neg.val) == 1
  143. assert len(prep_rel.edges_neg.test) == 1
  144. assert len(prep_rel.edges_back_pos.train) == 0
  145. assert len(prep_rel.edges_back_pos.val) == 0
  146. assert len(prep_rel.edges_back_pos.test) == 0
  147. assert len(prep_rel.edges_back_neg.train) == 0
  148. assert len(prep_rel.edges_back_neg.val) == 0
  149. assert len(prep_rel.edges_back_neg.test) == 0
  150. def test_prep_rel_two_node_types_asym_01():
  151. adj_mat = torch.zeros(200)
  152. perm = torch.randperm(100)
  153. adj_mat[perm[:10]] = 1
  154. adj_mat = adj_mat.view(10, 20)
  155. adj_mat_back = torch.zeros(200)
  156. perm = torch.randperm(100)
  157. adj_mat_back[perm[:10]] = 1
  158. adj_mat_back = adj_mat_back.view(20, 10)
  159. rel = RelationType('Dummy Relation', 0, 1, adj_mat, adj_mat_back)
  160. ratios = TrainValTest(.8, .1, .1)
  161. prep_rel = prep_rel_two_node_types_asym(rel, ratios)
  162. assert prep_rel.name == rel.name
  163. assert prep_rel.node_type_row == rel.node_type_row
  164. assert prep_rel.node_type_column == rel.node_type_column
  165. assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
  166. assert prep_rel.adjacency_matrix_backward.shape == rel.adjacency_matrix_backward.shape
  167. assert len(prep_rel.edges_pos.train) == 8
  168. assert len(prep_rel.edges_pos.val) == 1
  169. assert len(prep_rel.edges_pos.test) == 1
  170. assert len(prep_rel.edges_neg.train) == 8
  171. assert len(prep_rel.edges_neg.val) == 1
  172. assert len(prep_rel.edges_neg.test) == 1
  173. assert len(prep_rel.edges_back_pos.train) == 8
  174. assert len(prep_rel.edges_back_pos.val) == 1
  175. assert len(prep_rel.edges_back_pos.test) == 1
  176. assert len(prep_rel.edges_back_neg.train) == 8
  177. assert len(prep_rel.edges_back_neg.val) == 1
  178. assert len(prep_rel.edges_back_neg.test) == 1
  179. # def prepare_relation(r, ratios):
  180. # adj_mat = r.adjacency_matrix
  181. # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)
  182. #
  183. # if r.node_type_row == r.node_type_column:
  184. # adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  185. # else:
  186. # adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  187. #
  188. # return PreparedRelation(r.name, r.node_type_row, r.node_type_column,
  189. # adj_mat_train, edges_pos, edges_neg)