IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

124 řádky
4.4KB

  1. from .data import Data
  2. from .model import TrainingBatch
  3. import torch
  4. from functools import reduce
  5. def _shuffle(x: torch.Tensor) -> torch.Tensor:
  6. order = torch.randperm(len(x))
  7. return x[order]
  8. def _same_data_org(pos_data: Data, neg_data: Data):
  9. if len(pos_data.vertex_types) != len(neg_data.vertex_types):
  10. return False
  11. test = [ pos_data.vertex_types[i].name == neg_data.vertex_types[i].name \
  12. and pos_data.vertex_types[i].count == neg_data.vertex_types[i].count \
  13. for i in range(len(pos_data.vertex_types)) ]
  14. if not all(test):
  15. return False
  16. if not set(pos_data.edge_types.keys()) == \
  17. set(neg_data.edge_types.keys()):
  18. return False
  19. test = [ pos_data.edge_types[i].name == \
  20. neg_data.edge_types[i].name \
  21. and pos_data.edge_types[i].vertex_type_row == \
  22. neg_data.edge_types[i].vertex_type_row \
  23. and pos_data.edge_types[i].vertex_type_column == \
  24. neg_data.edge_types[i].vertex_type_column \
  25. and len(pos_data.edge_types[i].adjacency_matrices) == \
  26. len(neg_data.edge_types[i].adjacency_matrices) \
  27. for i in pos_data.edge_types.keys() ]
  28. if not all(test):
  29. return False
  30. test = [ [ len(pos_data.edge_types[i].adjacency_matrices[k].values()) == \
  31. len(neg_data.edge_types[i].adjacency_matrices[k].values()) \
  32. for k in range(len(pos_data.edge_types[i])) ] \
  33. for i in pos_data.edge_types.keys() ]
  34. test = reduce(list.__add__, test)
  35. if not all(test):
  36. return False
  37. return True
  38. class DualBatcher(object):
  39. def __init__(self, pos_data: Data, neg_data: Data,
  40. batch_size: int=512, shuffle: bool=True) -> None:
  41. if not isinstance(pos_data, Data):
  42. raise TypeError('pos_data must be an instance of Data')
  43. if not isinstance(neg_data, Data):
  44. raise TypeError('neg_data must be an instance of Data')
  45. if not _same_data_org(pos_data, neg_data):
  46. raise ValueError('pos_data and neg_data must have the same organization')
  47. self.pos_data = pos_data
  48. self.neg_data = neg_data
  49. self.batch_size = int(batch_size)
  50. self.shuffle = bool(shuffle)
  51. def __iter__(self):
  52. class Batcher(object):
  53. def __init__(self, data: Data, batch_size: int=512,
  54. shuffle: bool=True) -> None:
  55. if not isinstance(data, Data):
  56. raise TypeError('data must be an instance of Data')
  57. self.data = data
  58. self.batch_size = int(batch_size)
  59. self.shuffle = bool(shuffle)
  60. def __iter__(self) -> TrainingBatch:
  61. edge_types = list(self.data.edge_types.values())
  62. edge_lists = [ [ adj_mat.indices().transpose(0, 1) \
  63. for adj_mat in et.adjacency_matrices ] \
  64. for et in edge_types ]
  65. if self.shuffle:
  66. edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \
  67. for edge_lst in edge_lists ]
  68. offsets = [ [ 0 ] * len(et.adjacency_matrices) \
  69. for et in edge_types ]
  70. while True:
  71. candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \
  72. if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \
  73. if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ]
  74. if len(candidates) == 0:
  75. break
  76. edge_idx = torch.randint(0, len(candidates), (1,)).item()
  77. edge_idx = candidates[edge_idx]
  78. candidates = [ rel_idx \
  79. for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \
  80. if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]
  81. rel_idx = torch.randint(0, len(candidates), (1,)).item()
  82. rel_idx = candidates[rel_idx]
  83. lst = edge_lists[edge_idx][rel_idx]
  84. et = edge_types[edge_idx]
  85. ofs = offsets[edge_idx][rel_idx]
  86. lst = lst[ofs:ofs+self.batch_size]
  87. offsets[edge_idx][rel_idx] += self.batch_size
  88. b = TrainingBatch(et.vertex_type_row, et.vertex_type_column,
  89. rel_idx, lst, torch.full((len(lst),), self.data.target_value,
  90. dtype=torch.float32))
  91. yield b