IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

125 行
5.0KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from icosagon.trainprep import TrainValTest, \
  6. train_val_test_split_edges, \
  7. get_edges_and_degrees, \
  8. prepare_adj_mat, \
  9. prepare_relation_type
  10. import torch
  11. import pytest
  12. import numpy as np
  13. from itertools import chain
  14. from icosagon.data import RelationType
  15. def test_train_val_test_split_edges_01():
  16. edges = torch.randint(0, 10, (10, 2))
  17. with pytest.raises(ValueError):
  18. _ = train_val_test_split_edges(edges, TrainValTest(.5, .5, .5))
  19. with pytest.raises(ValueError):
  20. _ = train_val_test_split_edges(edges, TrainValTest(.2, .2, .2))
  21. with pytest.raises(ValueError):
  22. _ = train_val_test_split_edges(edges, None)
  23. with pytest.raises(ValueError):
  24. _ = train_val_test_split_edges(edges, (.8, .1, .1))
  25. with pytest.raises(ValueError):
  26. _ = train_val_test_split_edges(np.random.randint(0, 10, (10, 2)), TrainValTest(.8, .1, .1))
  27. with pytest.raises(ValueError):
  28. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 3)), TrainValTest(.8, .1, .1))
  29. with pytest.raises(ValueError):
  30. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 2, 1)), TrainValTest(.8, .1, .1))
  31. with pytest.raises(ValueError):
  32. _ = train_val_test_split_edges(None, TrainValTest(.8, .2, .2))
  33. res = train_val_test_split_edges(edges, TrainValTest(.8, .1, .1))
  34. assert res.train.shape == (8, 2) and res.val.shape == (1, 2) and \
  35. res.test.shape == (1, 2)
  36. res = train_val_test_split_edges(edges, TrainValTest(.8, .0, .2))
  37. assert res.train.shape == (8, 2) and res.val.shape == (0, 2) and \
  38. res.test.shape == (2, 2)
  39. res = train_val_test_split_edges(edges, TrainValTest(.8, .2, .0))
  40. assert res.train.shape == (8, 2) and res.val.shape == (2, 2) and \
  41. res.test.shape == (0, 2)
  42. res = train_val_test_split_edges(edges, TrainValTest(.0, .5, .5))
  43. assert res.train.shape == (0, 2) and res.val.shape == (5, 2) and \
  44. res.test.shape == (5, 2)
  45. res = train_val_test_split_edges(edges, TrainValTest(.0, .0, 1.))
  46. assert res.train.shape == (0, 2) and res.val.shape == (0, 2) and \
  47. res.test.shape == (10, 2)
  48. res = train_val_test_split_edges(edges, TrainValTest(.0, 1., .0))
  49. assert res.train.shape == (0, 2) and res.val.shape == (10, 2) and \
  50. res.test.shape == (0, 2)
  51. def test_train_val_test_split_edges_02():
  52. edges = torch.randint(0, 30, (30, 2))
  53. ratios = TrainValTest(.8, .1, .1)
  54. res = train_val_test_split_edges(edges, ratios)
  55. edges = [ tuple(a) for a in edges ]
  56. res = [ tuple(a) for a in chain(res.train, res.val, res.test) ]
  57. assert all([ a in edges for a in res ])
  58. def test_get_edges_and_degrees_01():
  59. adj_mat_dense = (torch.rand((10, 10)) > .5)
  60. adj_mat_sparse = adj_mat_dense.to_sparse()
  61. edges_dense, degrees_dense = get_edges_and_degrees(adj_mat_dense)
  62. edges_sparse, degrees_sparse = get_edges_and_degrees(adj_mat_sparse)
  63. assert torch.all(degrees_dense == degrees_sparse)
  64. edges_dense = [ tuple(a) for a in edges_dense ]
  65. edges_sparse = [ tuple(a) for a in edges_dense ]
  66. assert len(edges_dense) == len(edges_sparse)
  67. assert all([ a in edges_dense for a in edges_sparse ])
  68. assert all([ a in edges_sparse for a in edges_dense ])
  69. # assert torch.all(edges_dense == edges_sparse)
  70. def test_prepare_adj_mat_01():
  71. adj_mat = (torch.rand((10, 10)) > .5)
  72. adj_mat = adj_mat.to_sparse()
  73. ratios = TrainValTest(.8, .1, .1)
  74. _ = prepare_adj_mat(adj_mat, ratios)
  75. def test_prepare_adj_mat_02():
  76. adj_mat = (torch.rand((10, 10)) > .5)
  77. adj_mat = adj_mat.to_sparse()
  78. ratios = TrainValTest(.8, .1, .1)
  79. (adj_mat_train, edges_pos, edges_neg) = prepare_adj_mat(adj_mat, ratios)
  80. assert isinstance(adj_mat_train, torch.Tensor)
  81. assert adj_mat_train.is_sparse
  82. assert adj_mat_train.shape == adj_mat.shape
  83. assert adj_mat_train.dtype == adj_mat.dtype
  84. assert isinstance(edges_pos, TrainValTest)
  85. assert isinstance(edges_neg, TrainValTest)
  86. for a in ['train', 'val', 'test']:
  87. for b in [edges_pos, edges_neg]:
  88. edges = getattr(b, a)
  89. assert isinstance(edges, torch.Tensor)
  90. assert len(edges.shape) == 2
  91. assert edges.shape[1] == 2
  92. def test_prepare_relation_type_01():
  93. adj_mat = (torch.rand((10, 10)) > .5)
  94. r = RelationType('Test', 0, 0, adj_mat, True)
  95. ratios = TrainValTest(.8, .1, .1)
  96. _ = prepare_relation_type(r, ratios)
  97. # def prepare_relation(r, ratios):
  98. # adj_mat = r.adjacency_matrix
  99. # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)
  100. #
  101. # if r.node_type_row == r.node_type_column:
  102. # adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  103. # else:
  104. # adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  105. #
  106. # return PreparedRelation(r.name, r.node_type_row, r.node_type_column,
  107. # adj_mat_train, edges_pos, edges_neg)