|
- import torch
-
-
- def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio):
- if train_ratio + val_ratio + test_ratio != 1.0:
- raise ValueError('Train, validation and test ratios must add up to 1')
-
- edges = torch.nonzero(adj_mat)
- order = torch.randperm(len(edges))
- edges = edges[order, :]
- n = round(len(edges) * train_ratio)
- edges_train = edges[:n]
- n_1 = round(len(edges) * (train_ratio + val_ratio))
- edges_val = edges[n:n_1]
- edges_test = edges[n_1:]
-
- adj_mat_train = torch.zeros_like(adj_mat)
- adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1
-
- adj_mat_val = torch.zeros_like(adj_mat)
- adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1
-
- adj_mat_test = torch.zeros_like(adj_mat)
- adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
-
- return adj_mat_train, adj_mat_val, adj_mat_test
|