If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or opening pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
3.6KB

  1. from decagon_pytorch.data import Data
  2. import torch
  3. from decagon_pytorch.splits import train_val_test_split_adj_mat
  4. import pytest
  5. def _gen_adj_mat(n_rows, n_cols):
  6. res = torch.rand((n_rows, n_cols)).round()
  7. if n_rows == n_cols:
  8. res -= torch.diag(torch.diag(res))
  9. a, b = torch.triu_indices(n_rows, n_cols)
  10. res[a, b] = res.transpose(0, 1)[a, b]
  11. return res
  12. def train_val_test_split_1(data, train_ratio=0.8,
  13. val_ratio=0.1, test_ratio=0.1):
  14. if train_ratio + val_ratio + test_ratio != 1.0:
  15. raise ValueError('Train, validation and test ratios must add up to 1')
  16. d_train = Data()
  17. d_val = Data()
  18. d_test = Data()
  19. for (node_type_row, node_type_col), rels in data.relation_types.items():
  20. for r in rels:
  21. adj_train, adj_val, adj_test = train_val_test_split_adj_mat(r.adjacency_matrix)
  22. d_train.add_relation_type(r.name, node_type_row, node_type_col, adj_train)
  23. d_val.add_relation_type(r.name, node_type_row, node_type_col, adj_train + adj_val)
  24. def train_val_test_split_2(data, train_ratio, val_ratio, test_ratio):
  25. if train_ratio + val_ratio + test_ratio != 1.0:
  26. raise ValueError('Train, validation and test ratios must add up to 1')
  27. for (node_type_row, node_type_col), rels in data.relation_types.items():
  28. for r in rels:
  29. adj_mat = r.adjacency_matrix
  30. edges = torch.nonzero(adj_mat)
  31. order = torch.randperm(len(edges))
  32. edges = edges[order, :]
  33. n = round(len(edges) * train_ratio)
  34. edges_train = edges[:n]
  35. n_1 = round(len(edges) * (train_ratio + val_ratio))
  36. edges_val = edges[n:n_1]
  37. edges_test = edges[n_1:]
  38. if len(edges_train) * len(edges_val) * len(edges_test) == 0:
  39. raise ValueError('Not enough edges to split into train/val/test sets for: ' + r.name)
  40. def test_train_val_test_split_adj_mat():
  41. adj_mat = _gen_adj_mat(50, 100)
  42. adj_mat_train, adj_mat_val, adj_mat_test = \
  43. train_val_test_split_adj_mat(adj_mat, train_ratio=0.8,
  44. val_ratio=0.1, test_ratio=0.1)
  45. assert adj_mat.shape == adj_mat_train.shape == \
  46. adj_mat_val.shape == adj_mat_test.shape
  47. edges_train = torch.nonzero(adj_mat_train)
  48. edges_val = torch.nonzero(adj_mat_val)
  49. edges_test = torch.nonzero(adj_mat_test)
  50. edges_train = set(map(tuple, edges_train.tolist()))
  51. edges_val = set(map(tuple, edges_val.tolist()))
  52. edges_test = set(map(tuple, edges_test.tolist()))
  53. assert edges_train.intersection(edges_val) == set()
  54. assert edges_train.intersection(edges_test) == set()
  55. assert edges_test.intersection(edges_val) == set()
  56. assert torch.all(adj_mat_train + adj_mat_val + adj_mat_test == adj_mat)
  57. # assert torch.all((edges_train != edges_val).sum(1).to(torch.bool))
  58. # assert torch.all((edges_train != edges_test).sum(1).to(torch.bool))
  59. # assert torch.all((edges_val != edges_test).sum(1).to(torch.bool))
  60. @pytest.mark.skip
  61. def test_splits_01():
  62. d = Data()
  63. d.add_node_type('Gene', 1000)
  64. d.add_node_type('Drug', 100)
  65. d.add_relation_type('Interaction', 0, 0,
  66. _gen_adj_mat(1000, 1000))
  67. d.add_relation_type('Target', 1, 0,
  68. _gen_adj_mat(100, 1000))
  69. d.add_relation_type('Side Effect: Insomnia', 1, 1,
  70. _gen_adj_mat(100, 100))
  71. d.add_relation_type('Side Effect: Incontinence', 1, 1,
  72. _gen_adj_mat(100, 100))
  73. d.add_relation_type('Side Effect: Abdominal pain', 1, 1,
  74. _gen_adj_mat(100, 100))
  75. d_train, d_val, d_test = train_val_test_split(d, 0.8, 0.1, 0.1)