IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

74 lines
1.7KB

  1. from .data import Data, \
  2. EdgeType
  3. from typing import Tuple, \
  4. List
  5. from .util import _sparse_coo_tensor
  6. import torch
  7. def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
  8. indices = adj_mat.indices()
  9. values = adj_mat.values()
  10. order = torch.randperm(indices.shape[1])
  11. indices = indices[:, order]
  12. values = values[order]
  13. ofs = 0
  14. res = []
  15. for r in ratios:
  16. # cnt = r * len(values)
  17. beg = int(ofs * len(values))
  18. end = int((ofs + r) * len(values))
  19. ofs += r
  20. ind = indices[:, beg:end]
  21. val = values[beg:end]
  22. res.append(_sparse_coo_tensor(ind, val, adj_mat.shape))
  23. # ofs += cnt
  24. return res
  25. def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
  26. res = [ split_adj_mat(adj_mat, ratios) \
  27. for adj_mat in et.adjacency_matrices ]
  28. res = [ EdgeType(et.name,
  29. et.vertex_type_row,
  30. et.vertex_type_column,
  31. [ a[i] for a in res ],
  32. et.decoder_factory,
  33. None ) for i in range(len(ratios)) ]
  34. return res
  35. def split_data(data: Data,
  36. ratios: List[float]):
  37. if not isinstance(data, Data):
  38. raise TypeError('data must be an instance of Data')
  39. ratios = list(ratios)
  40. if sum(ratios) != 1:
  41. raise ValueError('ratios must sum to 1')
  42. res = [ {} for _ in range(len(ratios)) ]
  43. for key, et in data.edge_types.items():
  44. for i, new_et in enumerate(split_edge_type(et, ratios)):
  45. res[i][key] = new_et
  46. res_1 = []
  47. for new_edge_types in res:
  48. d = Data()
  49. d.vertex_types = data.vertex_types,
  50. d.edge_types = new_edge_types
  51. res_1.append(d)
  52. return res_1