IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

60 рядків
1.5KB

import math
from typing import List, Tuple

import torch

from .data import Data, \
    TrainingBatch, \
    EdgeType
from .util import _sparse_coo_tensor
  6. def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
  7. indices = adj_mat.indices()
  8. values = adj_mat.values()
  9. order = torch.randperm(indices.shape[1])
  10. indices = indices[:, order]
  11. values = values[order]
  12. ofs = 0
  13. res = []
  14. for r in ratios:
  15. cnt = r * len(values)
  16. ind = indices[:, ofs:ofs+cnt]
  17. val = values[ofs:ofs+cnt]
  18. res.append(_sparse_coo_tensor(ind, val, adj_mat.shape))
  19. ofs += cnt
  20. return res
  21. def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
  22. res = [ [] for _ in range(len(et.adjacency_matrices)) ]
  23. for adj_mat in et.adjacency_matrices:
  24. for i, new_adj_mat in enumerate(split_adj_mat(adj_mat, ratios)):
  25. res[i].append(new_adj_mat)
  26. return res
  27. def split_data(data: Data,
  28. ratios: List[float]):
  29. if not isinstance(data, Data):
  30. raise TypeError('data must be an instance of Data')
  31. ratios = list(ratios)
  32. if sum(ratios) != 1:
  33. raise ValueError('ratios must sum to 1')
  34. res = [ {} for _ in range(len(ratios)) ]
  35. for key, et in data.edge_types:
  36. for i, new_et in enumerate(split_edge_type(et, ratios)):
  37. res[i][key] = new_et
  38. res = [ Data(data.vertex_types, new_edge_types) \
  39. for new_edge_types in res ]
  40. return res