|
@@ -1,7 +1,40 @@ |
|
|
import torch
|
|
|
import torch
|
|
|
|
|
|
from .data import Data, \
|
|
|
|
|
|
AdjListData
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio):
|
|
|
|
|
|
|
|
|
def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio,
|
|
|
|
|
|
return_edges=False):
|
|
|
|
|
|
|
|
|
|
|
|
if train_ratio + val_ratio + test_ratio != 1.0:
|
|
|
|
|
|
raise ValueError('Train, validation and test ratios must add up to 1')
|
|
|
|
|
|
|
|
|
|
|
|
edges = torch.nonzero(adj_mat)
|
|
|
|
|
|
order = torch.randperm(len(edges))
|
|
|
|
|
|
edges = edges[order, :]
|
|
|
|
|
|
n = round(len(edges) * train_ratio)
|
|
|
|
|
|
edges_train = edges[:n]
|
|
|
|
|
|
n_1 = round(len(edges) * (train_ratio + val_ratio))
|
|
|
|
|
|
edges_val = edges[n:n_1]
|
|
|
|
|
|
edges_test = edges[n_1:]
|
|
|
|
|
|
|
|
|
|
|
|
adj_mat_train = torch.zeros_like(adj_mat)
|
|
|
|
|
|
adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1
|
|
|
|
|
|
|
|
|
|
|
|
adj_mat_val = torch.zeros_like(adj_mat)
|
|
|
|
|
|
adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1
|
|
|
|
|
|
|
|
|
|
|
|
adj_mat_test = torch.zeros_like(adj_mat)
|
|
|
|
|
|
adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
|
|
|
|
|
|
|
|
|
|
|
|
res = (adj_mat_train, adj_mat_val, adj_mat_test)
|
|
|
|
|
|
if return_edges:
|
|
|
|
|
|
res += (edges_train, edges_val, edges_test)
|
|
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_val_test_split_edges(adj_mat, train_ratio, val_ratio, test_ratio):
|
|
|
if train_ratio + val_ratio + test_ratio != 1.0:
|
|
|
if train_ratio + val_ratio + test_ratio != 1.0:
|
|
|
raise ValueError('Train, validation and test ratios must add up to 1')
|
|
|
raise ValueError('Train, validation and test ratios must add up to 1')
|
|
|
|
|
|
|
|
@@ -23,4 +56,5 @@ def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio): |
|
|
adj_mat_test = torch.zeros_like(adj_mat)
|
|
|
adj_mat_test = torch.zeros_like(adj_mat)
|
|
|
adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
|
|
|
adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
|
|
|
|
|
|
|
|
|
return adj_mat_train, adj_mat_val, adj_mat_test
|
|
|
|
|
|
|
|
|
return adj_mat_train, adj_mat_val, adj_mat_test, \
|
|
|
|
|
|
edges_train, edges_val, edges_test
|