|
- from icosagon.trainprep import TrainValTest, \
- train_val_test_split_edges
- import torch
- import pytest
- import numpy as np
-
-
def _assert_split_shapes(res, n_train, n_val, n_test):
    """Assert the row counts of each subset of a split result.

    Every subset is expected to keep the two-column layout of the input
    edge tensor.
    """
    # torch.Size is a tuple subclass, so comparing the triple of shapes
    # against plain tuples lets pytest print a readable diff on failure,
    # unlike a single `a and b and c` assertion.
    assert (res.train.shape, res.val.shape, res.test.shape) == \
        ((n_train, 2), (n_val, 2), (n_test, 2))


def test_train_val_test_split_edges_01():
    """Check input validation and split sizes of train_val_test_split_edges."""
    edges = torch.randint(0, 10, (10, 2))

    # Ratios that do not sum to 1 are rejected.
    with pytest.raises(ValueError):
        train_val_test_split_edges(edges, TrainValTest(.5, .5, .5))
    with pytest.raises(ValueError):
        train_val_test_split_edges(edges, TrainValTest(.2, .2, .2))

    # Ratios must be given as a TrainValTest, not None or a plain tuple.
    with pytest.raises(ValueError):
        train_val_test_split_edges(edges, None)
    with pytest.raises(ValueError):
        train_val_test_split_edges(edges, (.8, .1, .1))

    # Edges must be a torch tensor with shape (n, 2).
    with pytest.raises(ValueError):
        train_val_test_split_edges(np.random.randint(0, 10, (10, 2)),
                                   TrainValTest(.8, .1, .1))
    with pytest.raises(ValueError):
        train_val_test_split_edges(torch.randint(0, 10, (10, 3)),
                                   TrainValTest(.8, .1, .1))
    with pytest.raises(ValueError):
        train_val_test_split_edges(torch.randint(0, 10, (10, 2, 1)),
                                   TrainValTest(.8, .1, .1))

    # Missing edges are rejected. The ratios here sum to 1, so the None
    # edges are the only possible cause of the ValueError.
    with pytest.raises(ValueError):
        train_val_test_split_edges(None, TrainValTest(.8, .1, .1))

    # Valid splits of 10 edges, including empty subsets.
    res = train_val_test_split_edges(edges, TrainValTest(.8, .1, .1))
    _assert_split_shapes(res, 8, 1, 1)
    res = train_val_test_split_edges(edges, TrainValTest(.8, .0, .2))
    _assert_split_shapes(res, 8, 0, 2)
    res = train_val_test_split_edges(edges, TrainValTest(.8, .2, .0))
    _assert_split_shapes(res, 8, 2, 0)
    res = train_val_test_split_edges(edges, TrainValTest(.0, .5, .5))
    _assert_split_shapes(res, 0, 5, 5)
    res = train_val_test_split_edges(edges, TrainValTest(.0, .0, 1.))
    _assert_split_shapes(res, 0, 0, 10)
    res = train_val_test_split_edges(edges, TrainValTest(.0, 1., .0))
    _assert_split_shapes(res, 0, 10, 0)
-
-
-
- # if ratios.train + ratios.val + ratios.test != 1.0:
- # raise ValueError('Train, validation and test ratios must add up to 1')
- #
- # order = torch.randperm(len(edges))
- # edges = edges[order, :]
- # n = round(len(edges) * ratios.train)
- # edges_train = edges[:n]
- # n_1 = round(len(edges) * (ratios.train + ratios.val))
- # edges_val = edges[n:n_1]
- # edges_test = edges[n_1:]
- #
- # return TrainValTest(edges_train, edges_val, edges_test)
|