IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

61 lines
1.9KB

  1. import torch
  2. from .data import Data, \
  3. AdjListData
  4. def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio,
  5. return_edges=False):
  6. if train_ratio + val_ratio + test_ratio != 1.0:
  7. raise ValueError('Train, validation and test ratios must add up to 1')
  8. edges = torch.nonzero(adj_mat)
  9. order = torch.randperm(len(edges))
  10. edges = edges[order, :]
  11. n = round(len(edges) * train_ratio)
  12. edges_train = edges[:n]
  13. n_1 = round(len(edges) * (train_ratio + val_ratio))
  14. edges_val = edges[n:n_1]
  15. edges_test = edges[n_1:]
  16. adj_mat_train = torch.zeros_like(adj_mat)
  17. adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1
  18. adj_mat_val = torch.zeros_like(adj_mat)
  19. adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1
  20. adj_mat_test = torch.zeros_like(adj_mat)
  21. adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
  22. res = (adj_mat_train, adj_mat_val, adj_mat_test)
  23. if return_edges:
  24. res += (edges_train, edges_val, edges_test)
  25. return res
  26. def train_val_test_split_edges(adj_mat, train_ratio, val_ratio, test_ratio):
  27. if train_ratio + val_ratio + test_ratio != 1.0:
  28. raise ValueError('Train, validation and test ratios must add up to 1')
  29. edges = torch.nonzero(adj_mat)
  30. order = torch.randperm(len(edges))
  31. edges = edges[order, :]
  32. n = round(len(edges) * train_ratio)
  33. edges_train = edges[:n]
  34. n_1 = round(len(edges) * (train_ratio + val_ratio))
  35. edges_val = edges[n:n_1]
  36. edges_test = edges[n_1:]
  37. adj_mat_train = torch.zeros_like(adj_mat)
  38. adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1
  39. adj_mat_val = torch.zeros_like(adj_mat)
  40. adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1
  41. adj_mat_test = torch.zeros_like(adj_mat)
  42. adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
  43. return adj_mat_train, adj_mat_val, adj_mat_test, \
  44. edges_train, edges_val, edges_test