diff --git a/src/decagon_pytorch/data/trainprep.py b/src/decagon_pytorch/data/trainprep.py new file mode 100644 index 0000000..18373ea --- /dev/null +++ b/src/decagon_pytorch/data/trainprep.py @@ -0,0 +1,2 @@ +def trainprep(data): + pass diff --git a/src/decagon_pytorch/splits.py b/src/decagon_pytorch/splits.py index e0394cc..9d219b6 100644 --- a/src/decagon_pytorch/splits.py +++ b/src/decagon_pytorch/splits.py @@ -1,7 +1,40 @@ import torch +from .data import Data, \ + AdjListData -def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio): +def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio, + return_edges=False): + + if train_ratio + val_ratio + test_ratio != 1.0: + raise ValueError('Train, validation and test ratios must add up to 1') + + edges = torch.nonzero(adj_mat) + order = torch.randperm(len(edges)) + edges = edges[order, :] + n = round(len(edges) * train_ratio) + edges_train = edges[:n] + n_1 = round(len(edges) * (train_ratio + val_ratio)) + edges_val = edges[n:n_1] + edges_test = edges[n_1:] + + adj_mat_train = torch.zeros_like(adj_mat) + adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1 + + adj_mat_val = torch.zeros_like(adj_mat) + adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1 + + adj_mat_test = torch.zeros_like(adj_mat) + adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1 + + res = (adj_mat_train, adj_mat_val, adj_mat_test) + if return_edges: + res += (edges_train, edges_val, edges_test) + + return res + + +def train_val_test_split_edges(adj_mat, train_ratio, val_ratio, test_ratio): if train_ratio + val_ratio + test_ratio != 1.0: raise ValueError('Train, validation and test ratios must add up to 1') @@ -23,4 +56,5 @@ def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio): adj_mat_test = torch.zeros_like(adj_mat) adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1 - return adj_mat_train, adj_mat_val, adj_mat_test + return adj_mat_train, adj_mat_val, adj_mat_test, \ + edges_train, edges_val, edges_test