If you would like an account, please send an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

82 lines
2.0KB

  1. from .data import Data, \
  2. EdgeType
  3. from typing import Tuple, \
  4. List
  5. from .util import _sparse_coo_tensor
  6. import torch
  7. def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
  8. ratios = list(ratios)
  9. if sum(ratios) != 1:
  10. raise ValueError('Sum of ratios must be 1')
  11. indices = adj_mat.indices()
  12. values = adj_mat.values()
  13. order = torch.randperm(indices.shape[1])
  14. indices = indices[:, order]
  15. values = values[order]
  16. ofs = 0
  17. res = []
  18. for r in ratios:
  19. # cnt = r * len(values)
  20. beg = int(ofs * len(values))
  21. end = int((ofs + r) * len(values))
  22. ofs += r
  23. ind = indices[:, beg:end]
  24. val = values[beg:end]
  25. res.append(_sparse_coo_tensor(ind, val, adj_mat.shape).coalesce())
  26. # ofs += cnt
  27. return res
  28. def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
  29. ratios = list(ratios)
  30. if sum(ratios) != 1:
  31. raise ValueError('Sum of ratios must be 1')
  32. res = [ split_adj_mat(adj_mat, ratios) \
  33. for adj_mat in et.adjacency_matrices ]
  34. res = [ EdgeType(et.name,
  35. et.vertex_type_row,
  36. et.vertex_type_column,
  37. [ a[i] for a in res ],
  38. et.decoder_factory,
  39. None ) for i in range(len(ratios)) ]
  40. return res
  41. def split_data(data: Data,
  42. ratios: List[float]):
  43. if not isinstance(data, Data):
  44. raise TypeError('data must be an instance of Data')
  45. ratios = list(ratios)
  46. if sum(ratios) != 1:
  47. raise ValueError('ratios must sum to 1')
  48. res = [ {} for _ in range(len(ratios)) ]
  49. for key, et in data.edge_types.items():
  50. for i, new_et in enumerate(split_edge_type(et, ratios)):
  51. res[i][key] = new_et
  52. res_1 = []
  53. for new_edge_types in res:
  54. d = Data()
  55. d.vertex_types = data.vertex_types
  56. d.edge_types = new_edge_types
  57. res_1.append(d)
  58. return res_1