|
- from decagon_pytorch.data import Data
- import torch
- from decagon_pytorch.splits import train_val_test_split_adj_mat
- import pytest
-
-
def _gen_adj_mat(n_rows, n_cols):
    """Return a random matrix of zeros and ones with the given shape.

    Entries are uniform random numbers rounded to 0 or 1. When the matrix
    is square it is additionally given a zero diagonal and made symmetric
    by mirroring the lower triangle into the upper triangle.
    """
    mat = torch.rand((n_rows, n_cols)).round()

    # Rectangular matrices are returned as drawn.
    if n_rows != n_cols:
        return mat

    # Remove self-loops: subtract the diagonal from itself.
    diagonal_only = torch.diag(torch.diag(mat))
    mat -= diagonal_only

    # Copy entry (j, i) into (i, j) for every upper-triangle position.
    upper_rows, upper_cols = torch.triu_indices(n_rows, n_cols)
    mirrored = mat.t()[upper_rows, upper_cols]
    mat[upper_rows, upper_cols] = mirrored
    return mat
-
-
def train_val_test_split_1(data, train_ratio=0.8,
                           val_ratio=0.1, test_ratio=0.1):
    """Split every relation of ``data`` with ``train_val_test_split_adj_mat``.

    For each relation found in ``data.relation_types`` the adjacency matrix
    is split into train / validation / test matrices using the given
    ratios. The train matrix is registered on a fresh ``Data`` object and
    the sum of the train and validation matrices on a second one.

    Raises:
        ValueError: if the three ratios do not add up to 1 (within a small
            floating-point tolerance).
    """
    # Compare with a tolerance: sums such as 0.7 + 0.2 + 0.1 are not exactly
    # 1.0 in binary floating point and must not be rejected.
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-9:
        raise ValueError('Train, validation and test ratios must add up to 1')

    d_train = Data()
    d_val = Data()
    d_test = Data()

    for (node_type_row, node_type_col), rels in data.relation_types.items():
        for r in rels:
            # Forward the validated ratios so the split actually honours
            # the arguments instead of the callee's defaults.
            adj_train, adj_val, adj_test = train_val_test_split_adj_mat(
                r.adjacency_matrix, train_ratio=train_ratio,
                val_ratio=val_ratio, test_ratio=test_ratio)
            d_train.add_relation_type(r.name, node_type_row,
                                      node_type_col, adj_train)
            # NOTE(review): the validation set receives train + val edges;
            # presumably intentional, confirm.
            d_val.add_relation_type(r.name, node_type_row,
                                    node_type_col, adj_train + adj_val)

    # NOTE(review): d_test / adj_test are never used and nothing is
    # returned; this function looks unfinished. Confirm the intended
    # contents of d_test before returning the three Data objects.
-
-
def train_val_test_split_2(data, train_ratio, val_ratio, test_ratio):
    """Check that every relation of ``data`` can be split three ways.

    For each relation in ``data.relation_types`` the non-zero entries of
    its adjacency matrix are shuffled and cut into consecutive train,
    validation and test slices sized by the given ratios. The slices are
    only validated here; nothing is stored or returned.

    Raises:
        ValueError: if the ratios do not add up to 1 (within a small
            floating-point tolerance), or if any of the three slices of a
            relation would be empty.
    """
    # Compare with a tolerance: sums such as 0.7 + 0.2 + 0.1 are not exactly
    # 1.0 in binary floating point and must not be rejected.
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-9:
        raise ValueError('Train, validation and test ratios must add up to 1')

    for rels in data.relation_types.values():
        for r in rels:
            edges = torch.nonzero(r.adjacency_matrix)
            n_edges = len(edges)

            # Shuffle the edge list before slicing it.
            edges = edges[torch.randperm(n_edges)]

            # Cumulative cut points for the train and train+val prefixes.
            n_train = round(n_edges * train_ratio)
            n_train_val = round(n_edges * (train_ratio + val_ratio))

            edges_train = edges[:n_train]
            edges_val = edges[n_train:n_train_val]
            edges_test = edges[n_train_val:]

            if min(len(edges_train), len(edges_val), len(edges_test)) == 0:
                raise ValueError('Not enough edges to split into train/val/test sets for: ' + r.name)

    # NOTE(review): the computed slices are discarded and nothing is
    # returned; this function looks unfinished.
-
-
def test_train_val_test_split_adj_mat():
    """Splitting a random 50x100 matrix must partition its edges.

    Checks that the three returned matrices keep the input shape, that no
    edge appears in more than one of them, and that they add back up to
    the original matrix.
    """
    full = _gen_adj_mat(50, 100)
    adj_mat_train, adj_mat_val, adj_mat_test = train_val_test_split_adj_mat(
        full, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)

    # Every split has the same shape as the input.
    assert all(part.shape == full.shape
               for part in (adj_mat_train, adj_mat_val, adj_mat_test))

    def edge_set(mat):
        # Non-zero positions as a set of (row, col) tuples.
        return {tuple(index) for index in torch.nonzero(mat).tolist()}

    train_edges = edge_set(adj_mat_train)
    val_edges = edge_set(adj_mat_val)
    test_edges = edge_set(adj_mat_test)

    # The three edge sets are pairwise disjoint.
    assert not (train_edges & val_edges)
    assert not (train_edges & test_edges)
    assert not (test_edges & val_edges)

    # Together the splits reproduce the original matrix exactly.
    recombined = adj_mat_train + adj_mat_val + adj_mat_test
    assert torch.all(recombined == full)
-
-
@pytest.mark.skip
def test_splits_01():
    """Build a small Gene/Drug data set and split it (currently skipped).

    Two node types and five relation types with random adjacency matrices
    are registered on a ``Data`` object, which is then passed to
    ``train_val_test_split`` with 0.8 / 0.1 / 0.1 ratios.
    """
    d = Data()
    for node_name, node_count in (('Gene', 1000), ('Drug', 100)):
        d.add_node_type(node_name, node_count)

    # (relation name, row node type, column node type, matrix shape)
    relations = (
        ('Interaction', 0, 0, (1000, 1000)),
        ('Target', 1, 0, (100, 1000)),
        ('Side Effect: Insomnia', 1, 1, (100, 100)),
        ('Side Effect: Incontinence', 1, 1, (100, 100)),
        ('Side Effect: Abdominal pain', 1, 1, (100, 100)),
    )
    for rel_name, row_type, col_type, shape in relations:
        d.add_relation_type(rel_name, row_type, col_type,
                            _gen_adj_mat(*shape))

    # NOTE(review): train_val_test_split is not defined or imported in the
    # visible part of this file; presumably that is why the test is skipped.
    d_train, d_val, d_test = train_val_test_split(d, 0.8, 0.1, 0.1)
|