IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Selaa lähdekoodia

Add trainprep.

master
Stanislaw Adaszewski 4 vuotta sitten
vanhempi
commit
e201608caf
2 muutettua tiedostoa jossa 38 lisäystä ja 2 poistoa
  1. +2
    -0
      src/decagon_pytorch/data/trainprep.py
  2. +36
    -2
      src/decagon_pytorch/splits.py

+ 2
- 0
src/decagon_pytorch/data/trainprep.py Näytä tiedosto

@@ -0,0 +1,2 @@
def trainprep(data):
pass

+ 36
- 2
src/decagon_pytorch/splits.py Näytä tiedosto

@@ -1,7 +1,40 @@
import torch
from .data import Data, \
AdjListData
def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio):
def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio,
return_edges=False):
if train_ratio + val_ratio + test_ratio != 1.0:
raise ValueError('Train, validation and test ratios must add up to 1')
edges = torch.nonzero(adj_mat)
order = torch.randperm(len(edges))
edges = edges[order, :]
n = round(len(edges) * train_ratio)
edges_train = edges[:n]
n_1 = round(len(edges) * (train_ratio + val_ratio))
edges_val = edges[n:n_1]
edges_test = edges[n_1:]
adj_mat_train = torch.zeros_like(adj_mat)
adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1
adj_mat_val = torch.zeros_like(adj_mat)
adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1
adj_mat_test = torch.zeros_like(adj_mat)
adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
res = (adj_mat_train, adj_mat_val, adj_mat_test)
if return_edges:
res += (edges_train, edges_val, edges_test)
return res
def train_val_test_split_edges(adj_mat, train_ratio, val_ratio, test_ratio):
if train_ratio + val_ratio + test_ratio != 1.0:
raise ValueError('Train, validation and test ratios must add up to 1')
@@ -23,4 +56,5 @@ def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio):
adj_mat_test = torch.zeros_like(adj_mat)
adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
return adj_mat_train, adj_mat_val, adj_mat_test
return adj_mat_train, adj_mat_val, adj_mat_test, \
edges_train, edges_val, edges_test

Loading…
Peruuta
Tallenna