import math

import torch


def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio, *, generator=None):
    """Randomly split the edges of an adjacency matrix into train/val/test masks.

    Every nonzero entry of ``adj_mat`` is treated as an edge. The edges are
    shuffled and partitioned, in order, into three disjoint groups whose sizes
    follow the given ratios. Each group is returned as a tensor with the same
    shape, dtype and device as ``adj_mat``, holding ``1`` at that group's edge
    positions and ``0`` elsewhere. Original edge values are not preserved.

    Args:
        adj_mat: 2-D tensor whose nonzero entries mark edges.
        train_ratio: Fraction of edges assigned to the training split.
        val_ratio: Fraction of edges assigned to the validation split.
        test_ratio: Fraction of edges assigned to the test split. In practice
            the test split receives every edge not assigned to train or
            validation.
        generator: Optional CPU ``torch.Generator`` used for the shuffle, for
            reproducible splits. Defaults to PyTorch's global RNG.

    Returns:
        Tuple ``(adj_mat_train, adj_mat_val, adj_mat_test)``.

    Raises:
        ValueError: If ``adj_mat`` is not 2-D, any ratio is negative, or the
            ratios do not sum to 1 (within floating-point tolerance).
    """
    if adj_mat.dim() != 2:
        raise ValueError(f'adj_mat must be 2-D, got {adj_mat.dim()} dimensions')
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError('Train, validation and test ratios must be non-negative')
    # Exact float equality would reject valid inputs such as 0.7 + 0.2 + 0.1,
    # which evaluates to 0.9999999999999999.
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, rel_tol=0.0, abs_tol=1e-9):
        raise ValueError('Train, validation and test ratios must add up to 1')

    edges = torch.nonzero(adj_mat)  # shape (num_edges, 2): row, col indices
    num_edges = edges.shape[0]
    # Draw the permutation on CPU (as the original code did, so the default
    # global RNG stream is unchanged), then move it next to the edges.
    order = torch.randperm(num_edges, generator=generator).to(edges.device)
    edges = edges[order]

    # Boundaries are computed from cumulative ratios so the three slices are
    # contiguous and together cover every edge exactly once.
    train_end = round(num_edges * train_ratio)
    val_end = round(num_edges * (train_ratio + val_ratio))

    adj_mat_train = _edges_to_mask(adj_mat, edges[:train_end])
    adj_mat_val = _edges_to_mask(adj_mat, edges[train_end:val_end])
    adj_mat_test = _edges_to_mask(adj_mat, edges[val_end:])
    return adj_mat_train, adj_mat_val, adj_mat_test


def _edges_to_mask(adj_mat, edges):
    """Return a zero tensor shaped like ``adj_mat`` with 1 at each (row, col) in ``edges``."""
    mask = torch.zeros_like(adj_mat)
    mask[edges[:, 0], edges[:, 1]] = 1
    return mask