If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

194 rindas
7.0KB

  from functools import reduce
  from typing import Iterator, Tuple

  import torch

  from .data import Data
  from .model import TrainingBatch
  5. def _shuffle(x: torch.Tensor) -> torch.Tensor:
  6. order = torch.randperm(len(x))
  7. return x[order]
  8. def _same_data_org(pos_data: Data, neg_data: Data):
  9. if len(pos_data.vertex_types) != len(neg_data.vertex_types):
  10. return False
  11. test = [ pos_data.vertex_types[i].name == neg_data.vertex_types[i].name \
  12. and pos_data.vertex_types[i].count == neg_data.vertex_types[i].count \
  13. for i in range(len(pos_data.vertex_types)) ]
  14. if not all(test):
  15. return False
  16. if not set(pos_data.edge_types.keys()) == \
  17. set(neg_data.edge_types.keys()):
  18. return False
  19. test = [ pos_data.edge_types[i].name == \
  20. neg_data.edge_types[i].name \
  21. and pos_data.edge_types[i].vertex_type_row == \
  22. neg_data.edge_types[i].vertex_type_row \
  23. and pos_data.edge_types[i].vertex_type_column == \
  24. neg_data.edge_types[i].vertex_type_column \
  25. and len(pos_data.edge_types[i].adjacency_matrices) == \
  26. len(neg_data.edge_types[i].adjacency_matrices) \
  27. for i in pos_data.edge_types.keys() ]
  28. if not all(test):
  29. return False
  30. test = [ [ len(pos_data.edge_types[i].adjacency_matrices[k].values()) == \
  31. len(neg_data.edge_types[i].adjacency_matrices[k].values()) \
  32. for k in range(len(pos_data.edge_types[i].adjacency_matrices)) ] \
  33. for i in pos_data.edge_types.keys() ]
  34. test = reduce(list.__add__, test, [])
  35. if not all(test):
  36. return False
  37. return True
  class DualBatcher(object):
      """Yield paired positive/negative edge mini-batches.

      ``pos_data`` and ``neg_data`` must share the same organization (checked
      with ``_same_data_org``). Each iteration step does the following:

      * pick an edge type, uniformly at random, that still has unread positive
        edges;
      * pick one of its relations that still has unread positive edges;
      * yield ``(pos_batch, neg_batch)``.

      The two ``TrainingBatch`` objects hold the next ``batch_size`` edges of
      that same edge type, relation and offset in each dataset. Their target
      values are 1 and 0 respectively.
      """

      def __init__(self, pos_data: Data, neg_data: Data,
                   batch_size: int=512, shuffle: bool=True) -> None:
          """Validate the two datasets and store the batching options.

          Raises:
              TypeError: if either dataset is not a ``Data`` instance.
              ValueError: if the datasets are not organized identically.
          """
          if not isinstance(pos_data, Data):
              raise TypeError('pos_data must be an instance of Data')
          if not isinstance(neg_data, Data):
              raise TypeError('neg_data must be an instance of Data')
          if not _same_data_org(pos_data, neg_data):
              raise ValueError('pos_data and neg_data must have the same organization')
          self.pos_data = pos_data
          self.neg_data = neg_data
          self.batch_size = int(batch_size)
          self.shuffle = bool(shuffle)

      def get_edge_lists(self, data: Data, keys=None):
          """Collect per-relation edge index lists for *data*.

          Args:
              data: dataset whose ``edge_types`` mapping is read.
              keys: optional iterable of edge-type keys that fixes the order of
                  the returned lists. Defaults to the iteration order of
                  ``data.edge_types``. Passing another dataset's keys aligns
                  the two results index for index.

          Returns:
              A tuple ``(edge_keys, edge_types, edge_lists, offsets)``:

              * ``edge_lists[i][r]`` is ``indices()`` of the r-th adjacency
                matrix of edge type ``i``, transposed to one row per stored
                entry. It is shuffled along its first dimension when
                ``self.shuffle`` is set.
              * ``offsets[i][r]`` starts at 0.

          Raises:
              KeyError: if *keys* contains a key missing from ``data.edge_types``.
          """
          edge_keys = list(data.edge_types.keys()) if keys is None else list(keys)
          edge_types = [data.edge_types[key] for key in edge_keys]
          edge_lists = [[adj_mat.indices().transpose(0, 1)
                         for adj_mat in et.adjacency_matrices]
                        for et in edge_types]
          if self.shuffle:
              edge_lists = [[_shuffle(lst) for lst in edge_lst]
                            for edge_lst in edge_lists]
          offsets = [[0] * len(et.adjacency_matrices) for et in edge_types]
          return (edge_keys, edge_types, edge_lists, offsets)

      def get_candidates(self, edge_lists, offsets):
          """Randomly pick a relation that still has unread edges.

          Returns:
              ``(edge_idx, rel_idx)``. The edge type is chosen uniformly among
              edge types with at least one unread relation. The relation is
              then chosen uniformly among that edge type's unread relations.
              Returns ``(None, None)`` once every relation is exhausted.
          """
          def has_unread(edge_idx, rel_idx):
              return offsets[edge_idx][rel_idx] < len(edge_lists[edge_idx][rel_idx])

          candidates = [edge_idx for edge_idx, edge_ofs in enumerate(offsets)
                        if any(has_unread(edge_idx, rel_idx)
                               for rel_idx in range(len(edge_ofs)))]
          if not candidates:
              return None, None
          edge_idx = candidates[torch.randint(0, len(candidates), (1,)).item()]
          candidates = [rel_idx for rel_idx in range(len(offsets[edge_idx]))
                        if has_unread(edge_idx, rel_idx)]
          rel_idx = candidates[torch.randint(0, len(candidates), (1,)).item()]
          return edge_idx, rel_idx

      def take_edges(self, edge_idx, rel_idx, edge_lists, offsets,
                     edge_types, target_value):
          """Consume the next ``batch_size`` edges of one relation.

          Advances ``offsets[edge_idx][rel_idx]`` in place by ``batch_size``.
          Returns a ``TrainingBatch`` whose targets are all ``target_value``
          as float32. The last batch of a relation may be shorter than
          ``batch_size``.
          """
          et = edge_types[edge_idx]
          ofs = offsets[edge_idx][rel_idx]
          lst = edge_lists[edge_idx][rel_idx][ofs:ofs + self.batch_size]
          offsets[edge_idx][rel_idx] += self.batch_size
          return TrainingBatch(et.vertex_type_row, et.vertex_type_column,
                               rel_idx, lst,
                               torch.full((len(lst),), target_value,
                                          dtype=torch.float32))

      def __iter__(self) -> Iterator[Tuple[TrainingBatch, TrainingBatch]]:
          """Iterate once over all positive edges, pairing each batch with negatives."""
          pos_edge_keys, pos_edge_types, pos_edge_lists, pos_offsets = \
              self.get_edge_lists(self.pos_data)
          # _same_data_org only guarantees equal key *sets*, so the two
          # edge_types mappings may iterate in different orders. Build the
          # negative structures in the positive key order so that a shared
          # (edge_idx, rel_idx) refers to the same edge type on both sides.
          _, neg_edge_types, neg_edge_lists, neg_offsets = \
              self.get_edge_lists(self.neg_data, keys=pos_edge_keys)
          while True:
              # The positive side drives selection. _same_data_org enforces
              # equal per-relation lengths, so the negative offsets stay in step.
              edge_idx, rel_idx = self.get_candidates(pos_edge_lists, pos_offsets)
              if edge_idx is None:
                  return
              pos_batch = self.take_edges(edge_idx, rel_idx, pos_edge_lists,
                                          pos_offsets, pos_edge_types, 1)
              neg_batch = self.take_edges(edge_idx, rel_idx, neg_edge_lists,
                                          neg_offsets, neg_edge_types, 0)
              yield (pos_batch, neg_batch)
  class Batcher(object):
      """Yield edge mini-batches from a single ``Data`` object.

      Each iteration step does the following:

      * pick an edge type, uniformly at random, that still has unread edges;
      * pick one of its relations that still has unread edges;
      * yield a ``TrainingBatch`` with the next ``batch_size`` edges of that
        relation, all labelled ``data.target_value``.
      """

      def __init__(self, data: Data, batch_size: int=512,
                   shuffle: bool=True) -> None:
          """Validate *data* and store the batching options.

          Raises:
              TypeError: if *data* is not a ``Data`` instance.
          """
          if not isinstance(data, Data):
              raise TypeError('data must be an instance of Data')
          self.data = data
          self.batch_size = int(batch_size)
          self.shuffle = bool(shuffle)

      def __iter__(self) -> Iterator[TrainingBatch]:
          """Iterate once over every edge, in randomly interleaved batches.

          Targets are float32 tensors filled with ``self.data.target_value``.
          The last batch of a relation may be shorter than ``batch_size``.
          """
          edge_types = list(self.data.edge_types.values())
          # edge_lists[i][r]: indices() of relation r of edge type i,
          # transposed to one row per stored entry.
          edge_lists = [[adj_mat.indices().transpose(0, 1)
                         for adj_mat in et.adjacency_matrices]
                        for et in edge_types]
          if self.shuffle:
              edge_lists = [[_shuffle(lst) for lst in edge_lst]
                            for edge_lst in edge_lists]
          # offsets[i][r]: position of the next unread edge. It may overshoot
          # the end once the relation's last batch has been taken.
          offsets = [[0] * len(et.adjacency_matrices) for et in edge_types]

          def has_unread(edge_idx, rel_idx):
              return offsets[edge_idx][rel_idx] < len(edge_lists[edge_idx][rel_idx])

          while True:
              candidates = [edge_idx for edge_idx, edge_ofs in enumerate(offsets)
                            if any(has_unread(edge_idx, rel_idx)
                                   for rel_idx in range(len(edge_ofs)))]
              if not candidates:
                  break
              edge_idx = candidates[torch.randint(0, len(candidates), (1,)).item()]
              candidates = [rel_idx for rel_idx in range(len(offsets[edge_idx]))
                            if has_unread(edge_idx, rel_idx)]
              rel_idx = candidates[torch.randint(0, len(candidates), (1,)).item()]
              et = edge_types[edge_idx]
              ofs = offsets[edge_idx][rel_idx]
              lst = edge_lists[edge_idx][rel_idx][ofs:ofs + self.batch_size]
              offsets[edge_idx][rel_idx] += self.batch_size
              yield TrainingBatch(et.vertex_type_row, et.vertex_type_column,
                                  rel_idx, lst,
                                  torch.full((len(lst),), self.data.target_value,
                                             dtype=torch.float32))