IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

59 lines
2.6KB

  1. from icosagon.trainprep import TrainValTest, \
  2. train_val_test_split_edges
  3. import torch
  4. import pytest
  5. import numpy as np
  6. def test_train_val_test_split_edges_01():
  7. edges = torch.randint(0, 10, (10, 2))
  8. with pytest.raises(ValueError):
  9. _ = train_val_test_split_edges(edges, TrainValTest(.5, .5, .5))
  10. with pytest.raises(ValueError):
  11. _ = train_val_test_split_edges(edges, TrainValTest(.2, .2, .2))
  12. with pytest.raises(ValueError):
  13. _ = train_val_test_split_edges(edges, None)
  14. with pytest.raises(ValueError):
  15. _ = train_val_test_split_edges(edges, (.8, .1, .1))
  16. with pytest.raises(ValueError):
  17. _ = train_val_test_split_edges(np.random.randint(0, 10, (10, 2)), TrainValTest(.8, .1, .1))
  18. with pytest.raises(ValueError):
  19. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 3)), TrainValTest(.8, .1, .1))
  20. with pytest.raises(ValueError):
  21. _ = train_val_test_split_edges(torch.randint(0, 10, (10, 2, 1)), TrainValTest(.8, .1, .1))
  22. with pytest.raises(ValueError):
  23. _ = train_val_test_split_edges(None, TrainValTest(.8, .2, .2))
  24. res = train_val_test_split_edges(edges, TrainValTest(.8, .1, .1))
  25. assert res.train.shape == (8, 2) and res.val.shape == (1, 2) and \
  26. res.test.shape == (1, 2)
  27. res = train_val_test_split_edges(edges, TrainValTest(.8, .0, .2))
  28. assert res.train.shape == (8, 2) and res.val.shape == (0, 2) and \
  29. res.test.shape == (2, 2)
  30. res = train_val_test_split_edges(edges, TrainValTest(.8, .2, .0))
  31. assert res.train.shape == (8, 2) and res.val.shape == (2, 2) and \
  32. res.test.shape == (0, 2)
  33. res = train_val_test_split_edges(edges, TrainValTest(.0, .5, .5))
  34. assert res.train.shape == (0, 2) and res.val.shape == (5, 2) and \
  35. res.test.shape == (5, 2)
  36. res = train_val_test_split_edges(edges, TrainValTest(.0, .0, 1.))
  37. assert res.train.shape == (0, 2) and res.val.shape == (0, 2) and \
  38. res.test.shape == (10, 2)
  39. res = train_val_test_split_edges(edges, TrainValTest(.0, 1., .0))
  40. assert res.train.shape == (0, 2) and res.val.shape == (10, 2) and \
  41. res.test.shape == (0, 2)
  42. # if ratios.train + ratios.val + ratios.test != 1.0:
  43. # raise ValueError('Train, validation and test ratios must add up to 1')
  44. #
  45. # order = torch.randperm(len(edges))
  46. # edges = edges[order, :]
  47. # n = round(len(edges) * ratios.train)
  48. # edges_train = edges[:n]
  49. # n_1 = round(len(edges) * (ratios.train + ratios.val))
  50. # edges_val = edges[n:n_1]
  51. # edges_test = edges[n_1:]
  52. #
  53. # return TrainValTest(edges_train, edges_val, edges_test)