| @@ -21,7 +21,7 @@ def _same_data_org(pos_data: Data, neg_data: Data): | |||
| if not set(pos_data.edge_types.keys()) == \ | |||
| set(neg_data.edge_types.keys()): | |||
| return False | |||
| test = [ pos_data.edge_types[i].name == \ | |||
| @@ -65,8 +65,78 @@ class DualBatcher(object): | |||
| self.batch_size = int(batch_size) | |||
| self.shuffle = bool(shuffle) | |||
def get_edge_lists(self, data: Data):
    """Collect per-relation edge index lists and zeroed read offsets.

    For every entry of ``data.edge_types`` and every adjacency matrix of
    that entry, the matrix's ``indices()`` are transposed (dims 0 and 1
    swapped) and stored as one edge list. When ``self.shuffle`` is set,
    every edge list is passed through ``_shuffle``.

    NOTE(review): adjacency matrices are presumably torch sparse tensors
    (``indices()`` is called on them) -- confirm against ``Data``.

    Returns:
        A 4-tuple ``(edge_keys, edge_types, edge_lists, offsets)`` where
        ``edge_keys`` / ``edge_types`` are the keys / values of
        ``data.edge_types`` in iteration order, ``edge_lists[i][j]`` is
        the edge list for relation ``j`` of edge type ``i``, and
        ``offsets[i][j]`` is the matching start offset, initially 0.
    """
    keys = []
    types = []
    for key, edge_type in data.edge_types.items():
        keys.append(key)
        types.append(edge_type)

    # Phase 1: extract the index lists of every relation.
    per_type_lists = []
    for edge_type in types:
        per_relation = []
        for adj_mat in edge_type.adjacency_matrices:
            per_relation.append(adj_mat.indices().transpose(0, 1))
        per_type_lists.append(per_relation)

    # Phase 2: optional shuffling, in the same type/relation order.
    if self.shuffle:
        shuffled = []
        for per_relation in per_type_lists:
            shuffled.append([_shuffle(lst) for lst in per_relation])
        per_type_lists = shuffled

    # One read cursor per relation, all starting at the beginning.
    cursors = []
    for edge_type in types:
        cursors.append([0] * len(edge_type.adjacency_matrices))

    return (keys, types, per_type_lists, cursors)
def get_candidates(self, edge_lists, offsets):
    """Randomly pick an (edge type, relation) pair that still has edges.

    A relation is "open" while its offset is smaller than the length of
    its edge list. An edge type is eligible while it has at least one
    open relation.

    Args:
        edge_lists: ``edge_lists[i][j]`` is the edge list of relation
            ``j`` of edge type ``i`` (anything supporting ``len``).
        offsets: ``offsets[i][j]`` is the current read offset into
            ``edge_lists[i][j]``.

    Returns:
        ``(edge_idx, rel_idx)`` drawn with ``torch.randint`` -- first an
        eligible edge type, then one of its open relations -- or
        ``(None, None)`` when no relation is open.
    """
    # For every edge type, the relation indices that are not exhausted.
    open_relations = []
    for type_pos, type_offsets in enumerate(offsets):
        remaining = []
        for rel_pos, rel_offset in enumerate(type_offsets):
            if rel_offset < len(edge_lists[type_pos][rel_pos]):
                remaining.append(rel_pos)
        open_relations.append(remaining)

    eligible_types = [type_pos
                      for type_pos, remaining in enumerate(open_relations)
                      if remaining]
    if not eligible_types:
        return None, None

    # First draw selects the edge type, second draw the relation.
    pick = torch.randint(0, len(eligible_types), (1,)).item()
    edge_idx = eligible_types[pick]

    eligible_rels = open_relations[edge_idx]
    pick = torch.randint(0, len(eligible_rels), (1,)).item()
    rel_idx = eligible_rels[pick]

    return edge_idx, rel_idx
def take_edges(self, edge_idx, rel_idx, edge_lists, offsets,
               edge_types, target_value):
    """Cut the next batch of edges for one relation and advance its offset.

    Args:
        edge_idx: Index of the edge type in ``edge_lists`` / ``edge_types``.
        rel_idx: Index of the relation within that edge type.
        edge_lists: ``edge_lists[edge_idx][rel_idx]`` is the sliceable
            edge list to read from.
        offsets: ``offsets[edge_idx][rel_idx]`` is the current read
            offset; it is incremented in place by ``self.batch_size``.
        edge_types: ``edge_types[edge_idx]`` supplies ``vertex_type_row``
            and ``vertex_type_column`` for the batch.
        target_value: Fill value of the float32 target vector.

    Returns:
        A ``TrainingBatch`` built from the row/column vertex types, the
        relation index, up to ``self.batch_size`` edges starting at the
        previous offset, and a float32 target tensor of the same length
        filled with ``target_value``.
    """
    edge_type = edge_types[edge_idx]
    start = offsets[edge_idx][rel_idx]
    batch_edges = edge_lists[edge_idx][rel_idx][start:start + self.batch_size]

    # The offset always advances by a full batch, so it may overshoot the
    # list length; a later slice from there is simply empty.
    offsets[edge_idx][rel_idx] += self.batch_size

    # torch.full requires the size as a sequence of ints; passing a bare
    # int (as the code did before) raises TypeError.
    target = torch.full((len(batch_edges),), target_value,
                        dtype=torch.float32)

    return TrainingBatch(edge_type.vertex_type_row,
                         edge_type.vertex_type_column,
                         rel_idx, batch_edges, target)
def __iter__(self):
    """Yield ``(pos_batch, neg_batch)`` pairs until positives run out.

    Edge lists and offsets are built once for ``self.pos_data`` and once
    for ``self.neg_data``. Each step picks an (edge type, relation) pair
    from the *positive* lists and offsets only, then takes a batch at the
    same indices from both the positive data (target 1) and the negative
    data (target 0). Iteration ends when ``get_candidates`` reports no
    remaining positive edges.
    """
    # The key lists returned by get_edge_lists are not needed here.
    _, pos_types, pos_lists, pos_cursors = \
        self.get_edge_lists(self.pos_data)
    _, neg_types, neg_lists, neg_cursors = \
        self.get_edge_lists(self.neg_data)

    edge_idx, rel_idx = self.get_candidates(pos_lists, pos_cursors)
    while edge_idx is not None:
        positive = self.take_edges(edge_idx, rel_idx, pos_lists,
                                   pos_cursors, pos_types, 1)
        negative = self.take_edges(edge_idx, rel_idx, neg_lists,
                                   neg_cursors, neg_types, 0)
        yield (positive, negative)
        # Draw the next pair only after the consumer resumes the generator.
        edge_idx, rel_idx = self.get_candidates(pos_lists, pos_cursors)
| class Batcher(object): | |||