import numpy as np
import pytest
import torch

from icosagon.trainprep import TrainValTest, \
    train_val_test_split_edges


def _assert_split(res, edges, n_train, n_val, n_test):
    """Assert that ``res`` splits ``edges`` into parts of the given sizes.

    Checks the shape of each of ``res.train``, ``res.val`` and ``res.test``
    and that, taken together, they contain exactly the rows of ``edges``
    (each row once, order ignored).
    """
    assert res.train.shape == (n_train, 2)
    assert res.val.shape == (n_val, 2)
    assert res.test.shape == (n_test, 2)
    # The split may reorder rows, so compare the parts against the input
    # as multisets of (src, dst) pairs rather than position by position.
    combined = torch.cat([res.train, res.val, res.test], dim=0)
    assert sorted(map(tuple, combined.tolist())) == \
        sorted(map(tuple, edges.tolist()))


def test_train_val_test_split_edges_01():
    """Validate input checking and split sizes of
    ``train_val_test_split_edges``.

    Expected sizes assume split boundaries at ``round(N * train)`` and
    ``round(N * (train + val))``, with the remaining rows going to test.
    With N = 10 every case below lands on (or within float noise of) an
    integer, so the rounding mode does not matter.
    """
    edges = torch.randint(0, 10, (10, 2))

    # Invalid ratios: must be a TrainValTest whose fields sum to 1.
    with pytest.raises(ValueError):
        train_val_test_split_edges(edges, TrainValTest(.5, .5, .5))
    with pytest.raises(ValueError):
        train_val_test_split_edges(edges, TrainValTest(.2, .2, .2))
    with pytest.raises(ValueError):
        train_val_test_split_edges(edges, None)
    with pytest.raises(ValueError):
        train_val_test_split_edges(edges, (.8, .1, .1))

    # Invalid edges. Ratios are valid in every case so that only the edge
    # validation can be responsible for the ValueError.
    valid_ratios = TrainValTest(.8, .1, .1)
    with pytest.raises(ValueError):
        train_val_test_split_edges(np.random.randint(0, 10, (10, 2)),
                                   valid_ratios)
    with pytest.raises(ValueError):
        train_val_test_split_edges(torch.randint(0, 10, (10, 3)),
                                   valid_ratios)
    with pytest.raises(ValueError):
        train_val_test_split_edges(torch.randint(0, 10, (10, 2, 1)),
                                   valid_ratios)
    with pytest.raises(ValueError):
        train_val_test_split_edges(None, valid_ratios)

    # Valid splits, including ones with empty parts.
    _assert_split(train_val_test_split_edges(edges, TrainValTest(.8, .1, .1)),
                  edges, 8, 1, 1)
    _assert_split(train_val_test_split_edges(edges, TrainValTest(.8, .0, .2)),
                  edges, 8, 0, 2)
    _assert_split(train_val_test_split_edges(edges, TrainValTest(.8, .2, .0)),
                  edges, 8, 2, 0)
    _assert_split(train_val_test_split_edges(edges, TrainValTest(.0, .5, .5)),
                  edges, 0, 5, 5)
    _assert_split(train_val_test_split_edges(edges, TrainValTest(.0, .0, 1.)),
                  edges, 0, 0, 10)
    _assert_split(train_val_test_split_edges(edges, TrainValTest(.0, 1., .0)),
                  edges, 0, 10, 0)