IF YOU WOULD LIKE AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

27 lines
898B

  1. import torch
  def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio,
                                   generator=None):
      """Randomly split the edges of an adjacency matrix into train/val/test sets.

      Every nonzero entry of ``adj_mat`` is treated as an edge. The edges are
      shuffled with ``torch.randperm`` and partitioned into three disjoint groups
      by the given ratios. Each group is returned as a new matrix shaped like
      ``adj_mat``, with ``1`` at each of its edge positions and ``0`` elsewhere.
      Original edge values (weights) are therefore not preserved.

      Args:
          adj_mat: 2-D tensor whose nonzero entries are edges.
          train_ratio: Fraction of edges assigned to the training matrix.
          val_ratio: Fraction of edges assigned to the validation matrix.
          test_ratio: Fraction of edges assigned to the test matrix. It is only
              used for validation; the test split receives every edge left over
              after the train and validation cut points are rounded.
          generator: Optional ``torch.Generator`` forwarded to ``torch.randperm``
              for a reproducible shuffle. ``None`` uses the global RNG.

      Returns:
          Tuple ``(adj_mat_train, adj_mat_val, adj_mat_test)`` of tensors with the
          same shape, dtype and device as ``adj_mat``.

      Raises:
          ValueError: If ``adj_mat`` is not 2-D, if any ratio is negative, or if
              the ratios do not sum to 1 within floating-point tolerance.
      """
      import math  # function-scope import keeps this block self-contained

      if adj_mat.dim() != 2:
          raise ValueError('adj_mat must be a 2-D tensor')
      if min(train_ratio, val_ratio, test_ratio) < 0:
          raise ValueError('Train, validation and test ratios must be non-negative')
      # An exact `!= 1.0` test rejects valid splits such as (0.7, 0.2, 0.1),
      # whose floating-point sum is 0.9999999999999999.
      if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0):
          raise ValueError('Train, validation and test ratios must add up to 1')

      # One (row, col) index pair per nonzero entry, in shuffled order.
      edges = torch.nonzero(adj_mat)
      num_edges = len(edges)
      edges = edges[torch.randperm(num_edges, generator=generator)]

      # Cut points into the shuffled edge list. The test split takes the
      # remainder, so every edge lands in exactly one split despite rounding.
      n_train = round(num_edges * train_ratio)
      n_train_val = round(num_edges * (train_ratio + val_ratio))
      splits = (edges[:n_train], edges[n_train:n_train_val], edges[n_train_val:])

      result = []
      for split_edges in splits:
          split_mat = torch.zeros_like(adj_mat)
          split_mat[split_edges[:, 0], split_edges[:, 1]] = 1
          result.append(split_mat)
      return tuple(result)