diff --git a/src/decagon_pytorch/splits.py b/src/decagon_pytorch/splits.py new file mode 100644 index 0000000..e0394cc --- /dev/null +++ b/src/decagon_pytorch/splits.py @@ -0,0 +1,26 @@ +import torch + + +def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio): + if train_ratio + val_ratio + test_ratio != 1.0: + raise ValueError('Train, validation and test ratios must add up to 1') + + edges = torch.nonzero(adj_mat) + order = torch.randperm(len(edges)) + edges = edges[order, :] + n = round(len(edges) * train_ratio) + edges_train = edges[:n] + n_1 = round(len(edges) * (train_ratio + val_ratio)) + edges_val = edges[n:n_1] + edges_test = edges[n_1:] + + adj_mat_train = torch.zeros_like(adj_mat) + adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1 + + adj_mat_val = torch.zeros_like(adj_mat) + adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1 + + adj_mat_test = torch.zeros_like(adj_mat) + adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1 + + return adj_mat_train, adj_mat_val, adj_mat_test diff --git a/tests/decagon_pytorch/test_splits.py b/tests/decagon_pytorch/test_splits.py new file mode 100644 index 0000000..79482fb --- /dev/null +++ b/tests/decagon_pytorch/test_splits.py @@ -0,0 +1,95 @@ +from decagon_pytorch.data import Data +import torch +from decagon_pytorch.splits import train_val_test_split_adj_mat +import pytest + + +def _gen_adj_mat(n_rows, n_cols): + res = torch.rand((n_rows, n_cols)).round() + if n_rows == n_cols: + res -= torch.diag(torch.diag(res)) + a, b = torch.triu_indices(n_rows, n_cols) + res[a, b] = res.transpose(0, 1)[a, b] + return res + + +def train_val_test_split_1(data, train_ratio=0.8, + val_ratio=0.1, test_ratio=0.1): + + if train_ratio + val_ratio + test_ratio != 1.0: + raise ValueError('Train, validation and test ratios must add up to 1') + + d_train = Data() + d_val = Data() + d_test = Data() + + for (node_type_row, node_type_col), rels in data.relation_types.items(): + for r in rels: + adj_train, adj_val, adj_test = train_val_test_split_adj_mat(r.adjacency_matrix) + d_train.add_relation_type(r.name, node_type_row, node_type_col, adj_train) + d_val.add_relation_type(r.name, node_type_row, node_type_col, adj_train + adj_val) + + +def train_val_test_split_2(data, train_ratio, val_ratio, test_ratio): + if train_ratio + val_ratio + test_ratio != 1.0: + raise ValueError('Train, validation and test ratios must add up to 1') + for (node_type_row, node_type_col), rels in data.relation_types.items(): + for r in rels: + adj_mat = r.adjacency_matrix + edges = torch.nonzero(adj_mat) + order = torch.randperm(len(edges)) + edges = edges[order, :] + n = round(len(edges) * train_ratio) + edges_train = edges[:n] + n_1 = round(len(edges) * (train_ratio + val_ratio)) + edges_val = edges[n:n_1] + edges_test = edges[n_1:] + if len(edges_train) * len(edges_val) * len(edges_test) == 0: + raise ValueError('Not enough edges to split into train/val/test sets for: ' + r.name) + + +def test_train_val_test_split_adj_mat(): + adj_mat = _gen_adj_mat(50, 100) + adj_mat_train, adj_mat_val, adj_mat_test = \ + train_val_test_split_adj_mat(adj_mat, train_ratio=0.8, + val_ratio=0.1, test_ratio=0.1) + + assert adj_mat.shape == adj_mat_train.shape == \ + adj_mat_val.shape == adj_mat_test.shape + + edges_train = torch.nonzero(adj_mat_train) + edges_val = torch.nonzero(adj_mat_val) + edges_test = torch.nonzero(adj_mat_test) + + edges_train = set(map(tuple, edges_train.tolist())) + edges_val = set(map(tuple, edges_val.tolist())) + edges_test = set(map(tuple, edges_test.tolist())) + + assert edges_train.intersection(edges_val) == set() + assert edges_train.intersection(edges_test) == set() + assert edges_test.intersection(edges_val) == set() + + assert torch.all(adj_mat_train + adj_mat_val + adj_mat_test == adj_mat) + + # assert torch.all((edges_train != edges_val).sum(1).to(torch.bool)) + # assert torch.all((edges_train != edges_test).sum(1).to(torch.bool)) + # assert torch.all((edges_val != edges_test).sum(1).to(torch.bool)) + + +@pytest.mark.skip +def test_splits_01(): + d = Data() + d.add_node_type('Gene', 1000) + d.add_node_type('Drug', 100) + d.add_relation_type('Interaction', 0, 0, + _gen_adj_mat(1000, 1000)) + d.add_relation_type('Target', 1, 0, + _gen_adj_mat(100, 1000)) + d.add_relation_type('Side Effect: Insomnia', 1, 1, + _gen_adj_mat(100, 100)) + d.add_relation_type('Side Effect: Incontinence', 1, 1, + _gen_adj_mat(100, 100)) + d.add_relation_type('Side Effect: Abdominal pain', 1, 1, + _gen_adj_mat(100, 100)) + + d_train, d_val, d_test = train_val_test_split(d, 0.8, 0.1, 0.1)