From 6b15ec8c103e48c5b34c553e9a4abb4117b8c6a8 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 18 Aug 2020 13:59:29 +0200 Subject: [PATCH] Finish DualBatcher. --- src/triacontagon/batch.py | 72 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/src/triacontagon/batch.py b/src/triacontagon/batch.py index b93ab36..ec66f5f 100644 --- a/src/triacontagon/batch.py +++ b/src/triacontagon/batch.py @@ -21,7 +21,7 @@ def _same_data_org(pos_data: Data, neg_data: Data): if not set(pos_data.edge_types.keys()) == \ set(neg_data.edge_types.keys()): - + return False test = [ pos_data.edge_types[i].name == \ @@ -65,8 +65,78 @@ class DualBatcher(object): self.batch_size = int(batch_size) self.shuffle = bool(shuffle) + def get_edge_lists(self, data: Data): + edge_types = list(data.edge_types.items()) + edge_keys = [ a[0] for a in edge_types ] + edge_types = [ a[1] for a in edge_types ] + + edge_lists = [ [ adj_mat.indices().transpose(0, 1) \ + for adj_mat in et.adjacency_matrices ] \ + for et in edge_types ] + + if self.shuffle: + edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \ + for edge_lst in edge_lists ] + + offsets = [ [ 0 ] * len(et.adjacency_matrices) \ + for et in edge_types ] + + return (edge_keys, edge_types, edge_lists, offsets) + + def get_candidates(self, edge_lists, offsets): + candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \ + if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \ + if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ] + + if len(candidates) == 0: + return None, None + + edge_idx = torch.randint(0, len(candidates), (1,)).item() + edge_idx = candidates[edge_idx] + candidates = [ rel_idx \ + for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \ + if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ] + + rel_idx = torch.randint(0, len(candidates), (1,)).item() + rel_idx = candidates[rel_idx] + + return edge_idx, rel_idx + + def take_edges(self, edge_idx, rel_idx, edge_lists, offsets, + edge_types, target_value): + + lst = edge_lists[edge_idx][rel_idx] + et = edge_types[edge_idx] + ofs = offsets[edge_idx][rel_idx] + lst = lst[ofs:ofs+self.batch_size] + offsets[edge_idx][rel_idx] += self.batch_size + + res = TrainingBatch(et.vertex_type_row, et.vertex_type_column, + rel_idx, lst, torch.full(len(lst), target_value, + dtype=torch.float32)) + + return res + def __iter__(self): + pos_edge_keys, pos_edge_types, pos_edge_lists, pos_offsets = \ + self.get_edge_lists(self.pos_data) + + neg_edge_keys, neg_edge_types, neg_edge_lists, neg_offsets = \ + self.get_edge_lists(self.neg_data) + + while True: + edge_idx, rel_idx = self.get_candidates(pos_edge_lists, pos_offsets) + + if edge_idx is None: + return + + pos_batch = self.take_edges(edge_idx, rel_idx, pos_edge_lists, + pos_offsets, pos_edge_types, 1) + + neg_batch = self.take_edges(edge_idx, rel_idx, neg_edge_lists, + neg_offsets, neg_edge_types, 0) + yield (pos_batch, neg_batch) class Batcher(object):