IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Parcourir la source

Add train_val_test_split_adj_mat().

master
Stanislaw Adaszewski il y a 4 ans
Parent
révision
a74884dadd
2 fichiers modifiés avec 121 ajouts et 0 suppressions
  1. +26
    -0
      src/decagon_pytorch/splits.py
  2. +95
    -0
      tests/decagon_pytorch/test_splits.py

+ 26
- 0
src/decagon_pytorch/splits.py Voir le fichier

@@ -0,0 +1,26 @@
import torch
def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio):
if train_ratio + val_ratio + test_ratio != 1.0:
raise ValueError('Train, validation and test ratios must add up to 1')
edges = torch.nonzero(adj_mat)
order = torch.randperm(len(edges))
edges = edges[order, :]
n = round(len(edges) * train_ratio)
edges_train = edges[:n]
n_1 = round(len(edges) * (train_ratio + val_ratio))
edges_val = edges[n:n_1]
edges_test = edges[n_1:]
adj_mat_train = torch.zeros_like(adj_mat)
adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1
adj_mat_val = torch.zeros_like(adj_mat)
adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1
adj_mat_test = torch.zeros_like(adj_mat)
adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
return adj_mat_train, adj_mat_val, adj_mat_test

+ 95
- 0
tests/decagon_pytorch/test_splits.py Voir le fichier

@@ -0,0 +1,95 @@
from decagon_pytorch.data import Data
import torch
from decagon_pytorch.splits import train_val_test_split_adj_mat
import pytest
def _gen_adj_mat(n_rows, n_cols):
res = torch.rand((n_rows, n_cols)).round()
if n_rows == n_cols:
res -= torch.diag(torch.diag(res))
a, b = torch.triu_indices(n_rows, n_cols)
res[a, b] = res.transpose(0, 1)[a, b]
return res
def train_val_test_split_1(data, train_ratio=0.8,
val_ratio=0.1, test_ratio=0.1):
if train_ratio + val_ratio + test_ratio != 1.0:
raise ValueError('Train, validation and test ratios must add up to 1')
d_train = Data()
d_val = Data()
d_test = Data()
for (node_type_row, node_type_col), rels in data.relation_types.items():
for r in rels:
adj_train, adj_val, adj_test = train_val_test_split_adj_mat(r.adjacency_matrix)
d_train.add_relation_type(r.name, node_type_row, node_type_col, adj_train)
d_val.add_relation_type(r.name, node_type_row, node_type_col, adj_train + adj_val)
def train_val_test_split_2(data, train_ratio, val_ratio, test_ratio):
if train_ratio + val_ratio + test_ratio != 1.0:
raise ValueError('Train, validation and test ratios must add up to 1')
for (node_type_row, node_type_col), rels in data.relation_types.items():
for r in rels:
adj_mat = r.adjacency_matrix
edges = torch.nonzero(adj_mat)
order = torch.randperm(len(edges))
edges = edges[order, :]
n = round(len(edges) * train_ratio)
edges_train = edges[:n]
n_1 = round(len(edges) * (train_ratio + val_ratio))
edges_val = edges[n:n_1]
edges_test = edges[n_1:]
if len(edges_train) * len(edges_val) * len(edges_test) == 0:
raise ValueError('Not enough edges to split into train/val/test sets for: ' + r.name)
def test_train_val_test_split_adj_mat():
adj_mat = _gen_adj_mat(50, 100)
adj_mat_train, adj_mat_val, adj_mat_test = \
train_val_test_split_adj_mat(adj_mat, train_ratio=0.8,
val_ratio=0.1, test_ratio=0.1)
assert adj_mat.shape == adj_mat_train.shape == \
adj_mat_val.shape == adj_mat_test.shape
edges_train = torch.nonzero(adj_mat_train)
edges_val = torch.nonzero(adj_mat_val)
edges_test = torch.nonzero(adj_mat_test)
edges_train = set(map(tuple, edges_train.tolist()))
edges_val = set(map(tuple, edges_val.tolist()))
edges_test = set(map(tuple, edges_test.tolist()))
assert edges_train.intersection(edges_val) == set()
assert edges_train.intersection(edges_test) == set()
assert edges_test.intersection(edges_val) == set()
assert torch.all(adj_mat_train + adj_mat_val + adj_mat_test == adj_mat)
# assert torch.all((edges_train != edges_val).sum(1).to(torch.bool))
# assert torch.all((edges_train != edges_test).sum(1).to(torch.bool))
# assert torch.all((edges_val != edges_test).sum(1).to(torch.bool))
@pytest.mark.skip
def test_splits_01():
d = Data()
d.add_node_type('Gene', 1000)
d.add_node_type('Drug', 100)
d.add_relation_type('Interaction', 0, 0,
_gen_adj_mat(1000, 1000))
d.add_relation_type('Target', 1, 0,
_gen_adj_mat(100, 1000))
d.add_relation_type('Side Effect: Insomnia', 1, 1,
_gen_adj_mat(100, 100))
d.add_relation_type('Side Effect: Incontinence', 1, 1,
_gen_adj_mat(100, 100))
d.add_relation_type('Side Effect: Abdominal pain', 1, 1,
_gen_adj_mat(100, 100))
d_train, d_val, d_test = train_val_test_split(d, 0.8, 0.1, 0.1)

Chargement…
Annuler
Enregistrer