IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

154 行
6.2KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from icosagon.trainprep import TrainValTest, \
  6. train_val_test_split_edges, \
  7. get_edges_and_degrees, \
  8. prepare_adj_mat, \
  9. prepare_relation_type, \
  10. prep_rel_one_node_type
  11. import torch
  12. import pytest
  13. import numpy as np
  14. from itertools import chain
  15. from icosagon.data import RelationType
  16. def test_train_val_test_split_edges_01():
  17. edges = torch.randint(0, 10, (10, 2))
  18. with pytest.raises(ValueError):
  19. _ = train_val_test_split_edges(edges, TrainValTest(.5, .5, .5))
  20. with pytest.raises(ValueError):
  21. _ = train_val_test_split_edges(edges, TrainValTest(.2, .2, .2))
  22. with pytest.raises(ValueError):
  23. _ = train_val_test_split_edges(edges, None)
  24. with pytest.raises(ValueError):
  25. _ = train_val_test_split_edges(edges, (.8, .1, .1))
  26. with pytest.raises(ValueError):
  27. _ = train_val_test_split_edges(np.random.randint(0, 10, (10, 2)), TrainValTest(.8, .1, .1))
  28. with pytest.raises(ValueError):
  29. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 3)), TrainValTest(.8, .1, .1))
  30. with pytest.raises(ValueError):
  31. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 2, 1)), TrainValTest(.8, .1, .1))
  32. with pytest.raises(ValueError):
  33. _ = train_val_test_split_edges(None, TrainValTest(.8, .2, .2))
  34. res = train_val_test_split_edges(edges, TrainValTest(.8, .1, .1))
  35. assert res.train.shape == (8, 2) and res.val.shape == (1, 2) and \
  36. res.test.shape == (1, 2)
  37. res = train_val_test_split_edges(edges, TrainValTest(.8, .0, .2))
  38. assert res.train.shape == (8, 2) and res.val.shape == (0, 2) and \
  39. res.test.shape == (2, 2)
  40. res = train_val_test_split_edges(edges, TrainValTest(.8, .2, .0))
  41. assert res.train.shape == (8, 2) and res.val.shape == (2, 2) and \
  42. res.test.shape == (0, 2)
  43. res = train_val_test_split_edges(edges, TrainValTest(.0, .5, .5))
  44. assert res.train.shape == (0, 2) and res.val.shape == (5, 2) and \
  45. res.test.shape == (5, 2)
  46. res = train_val_test_split_edges(edges, TrainValTest(.0, .0, 1.))
  47. assert res.train.shape == (0, 2) and res.val.shape == (0, 2) and \
  48. res.test.shape == (10, 2)
  49. res = train_val_test_split_edges(edges, TrainValTest(.0, 1., .0))
  50. assert res.train.shape == (0, 2) and res.val.shape == (10, 2) and \
  51. res.test.shape == (0, 2)
  52. def test_train_val_test_split_edges_02():
  53. edges = torch.randint(0, 30, (30, 2))
  54. ratios = TrainValTest(.8, .1, .1)
  55. res = train_val_test_split_edges(edges, ratios)
  56. edges = [ tuple(a) for a in edges ]
  57. res = [ tuple(a) for a in chain(res.train, res.val, res.test) ]
  58. assert all([ a in edges for a in res ])
  59. def test_get_edges_and_degrees_01():
  60. adj_mat_dense = (torch.rand((10, 10)) > .5)
  61. adj_mat_sparse = adj_mat_dense.to_sparse()
  62. edges_dense, degrees_dense = get_edges_and_degrees(adj_mat_dense)
  63. edges_sparse, degrees_sparse = get_edges_and_degrees(adj_mat_sparse)
  64. assert torch.all(degrees_dense == degrees_sparse)
  65. edges_dense = [ tuple(a) for a in edges_dense ]
  66. edges_sparse = [ tuple(a) for a in edges_dense ]
  67. assert len(edges_dense) == len(edges_sparse)
  68. assert all([ a in edges_dense for a in edges_sparse ])
  69. assert all([ a in edges_sparse for a in edges_dense ])
  70. # assert torch.all(edges_dense == edges_sparse)
  71. def test_prepare_adj_mat_01():
  72. adj_mat = (torch.rand((10, 10)) > .5)
  73. adj_mat = adj_mat.to_sparse()
  74. ratios = TrainValTest(.8, .1, .1)
  75. _ = prepare_adj_mat(adj_mat, ratios)
  76. def test_prepare_adj_mat_02():
  77. adj_mat = (torch.rand((10, 10)) > .5)
  78. adj_mat = adj_mat.to_sparse()
  79. ratios = TrainValTest(.8, .1, .1)
  80. (adj_mat_train, edges_pos, edges_neg) = prepare_adj_mat(adj_mat, ratios)
  81. assert isinstance(adj_mat_train, torch.Tensor)
  82. assert adj_mat_train.is_sparse
  83. assert adj_mat_train.shape == adj_mat.shape
  84. assert adj_mat_train.dtype == adj_mat.dtype
  85. assert isinstance(edges_pos, TrainValTest)
  86. assert isinstance(edges_neg, TrainValTest)
  87. for a in ['train', 'val', 'test']:
  88. for b in [edges_pos, edges_neg]:
  89. edges = getattr(b, a)
  90. assert isinstance(edges, torch.Tensor)
  91. assert len(edges.shape) == 2
  92. assert edges.shape[1] == 2
  93. def test_prepare_relation_type_01():
  94. adj_mat = (torch.rand((10, 10)) > .5)
  95. r = RelationType('Test', 0, 0, adj_mat, True)
  96. ratios = TrainValTest(.8, .1, .1)
  97. _ = prepare_relation_type(r, ratios, False)
  98. def test_prep_rel_one_node_type_01():
  99. adj_mat = torch.zeros(100)
  100. perm = torch.randperm(100)
  101. adj_mat[perm[:10]] = 1
  102. adj_mat = adj_mat.view(10, 10)
  103. rel = RelationType('Dummy Relation', 0, 0, adj_mat, None)
  104. ratios = TrainValTest(.8, .1, .1)
  105. prep_rel = prep_rel_one_node_type(rel, ratios)
  106. assert prep_rel.name == rel.name
  107. assert prep_rel.node_type_row == rel.node_type_row
  108. assert prep_rel.node_type_column == rel.node_type_column
  109. assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
  110. assert prep_rel.adjacency_matrix_backward is None
  111. assert len(prep_rel.edges_pos.train) == 8
  112. assert len(prep_rel.edges_pos.val) == 1
  113. assert len(prep_rel.edges_pos.test) == 1
  114. assert len(prep_rel.edges_neg.train) == 8
  115. assert len(prep_rel.edges_neg.val) == 1
  116. assert len(prep_rel.edges_neg.test) == 1
  117. assert len(prep_rel.edges_back_pos.train) == 0
  118. assert len(prep_rel.edges_back_pos.val) == 0
  119. assert len(prep_rel.edges_back_pos.test) == 0
  120. assert len(prep_rel.edges_back_neg.train) == 0
  121. assert len(prep_rel.edges_back_neg.val) == 0
  122. assert len(prep_rel.edges_back_neg.test) == 0
  123. # def prepare_relation(r, ratios):
  124. # adj_mat = r.adjacency_matrix
  125. # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)
  126. #
  127. # if r.node_type_row == r.node_type_column:
  128. # adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  129. # else:
  130. # adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  131. #
  132. # return PreparedRelation(r.name, r.node_type_row, r.node_type_column,
  133. # adj_mat_train, edges_pos, edges_neg)