from .data import Data
from .model import TrainingBatch
import torch
from functools import reduce
from typing import Iterator, List, Optional, Sequence, Tuple


def _shuffle(x: torch.Tensor) -> torch.Tensor:
    """Return a new tensor holding the rows (first dimension) of ``x`` in random order."""
    order = torch.randperm(len(x))
    return x[order]


def _same_data_org(pos_data: Data, neg_data: Data) -> bool:
    """Return True when ``pos_data`` and ``neg_data`` are organized identically.

    The following must match:

    * vertex types: same number and, position by position, the same
      ``name`` and ``count``;
    * edge types: the same key set and, per key, the same ``name``,
      ``vertex_type_row``, ``vertex_type_column`` and number of
      adjacency matrices;
    * per adjacency matrix: the same number of stored values
      (``len(adj_mat.values())``).

    Note that only the key *set* of ``edge_types`` is compared, not the
    dicts' insertion order.
    """
    pos_vts = pos_data.vertex_types
    neg_vts = neg_data.vertex_types
    if len(pos_vts) != len(neg_vts):
        return False
    if not all(pv.name == nv.name and pv.count == nv.count
               for pv, nv in zip(pos_vts, neg_vts)):
        return False

    pos_ets = pos_data.edge_types
    neg_ets = neg_data.edge_types
    if set(pos_ets.keys()) != set(neg_ets.keys()):
        return False

    for key, pet in pos_ets.items():
        net = neg_ets[key]
        if not (pet.name == net.name
                and pet.vertex_type_row == net.vertex_type_row
                and pet.vertex_type_column == net.vertex_type_column
                and len(pet.adjacency_matrices) == len(net.adjacency_matrices)):
            return False
        # Equal entry counts per relation let DualBatcher consume the
        # positive and negative edge lists in lockstep.
        if not all(len(pa.values()) == len(na.values())
                   for pa, na in zip(pet.adjacency_matrices,
                                     net.adjacency_matrices)):
            return False

    return True


def _build_edge_lists(edge_types, shuffle: bool) \
        -> Tuple[List[List[torch.Tensor]], List[List[int]]]:
    """Extract per-relation edge lists and zeroed read offsets.

    For every edge type and each of its adjacency matrices, the stored
    index pairs are taken as ``adj_mat.indices().transpose(0, 1)``. This
    presumably gives one (row, column) edge per row, assuming 2-D sparse
    COO matrices; TODO confirm against Data. When ``shuffle`` is true,
    each list is randomly permuted.

    Returns ``(edge_lists, offsets)``. Both are nested as
    ``[edge_type_index][relation_index]``, and every offset starts at 0.
    """
    edge_lists = [[adj_mat.indices().transpose(0, 1)
                   for adj_mat in et.adjacency_matrices]
                  for et in edge_types]
    if shuffle:
        edge_lists = [[_shuffle(lst) for lst in rel_lists]
                      for rel_lists in edge_lists]
    offsets = [[0] * len(et.adjacency_matrices) for et in edge_types]
    return edge_lists, offsets


def _open_relations(edge_lists, offsets, edge_idx: int) -> List[int]:
    """Return the relation indices of edge type ``edge_idx`` that still have unread edges."""
    return [rel_idx for rel_idx, rel_ofs in enumerate(offsets[edge_idx])
            if rel_ofs < len(edge_lists[edge_idx][rel_idx])]


def _pick_relation(edge_lists, offsets) -> Tuple[Optional[int], Optional[int]]:
    """Randomly choose an (edge type, relation) pair that still has unread edges.

    The edge type is drawn uniformly among edge types that have at least
    one non-exhausted relation. The relation is then drawn uniformly among
    that edge type's non-exhausted relations. Returns ``(None, None)``
    once every list has been consumed.
    """
    edge_candidates = [edge_idx for edge_idx, _ in enumerate(offsets)
                       if _open_relations(edge_lists, offsets, edge_idx)]
    if not edge_candidates:
        return None, None
    pick = torch.randint(0, len(edge_candidates), (1,)).item()
    edge_idx = edge_candidates[pick]

    rel_candidates = _open_relations(edge_lists, offsets, edge_idx)
    pick = torch.randint(0, len(rel_candidates), (1,)).item()
    rel_idx = rel_candidates[pick]
    return edge_idx, rel_idx


def _take_batch(edge_types, edge_lists, offsets, edge_idx: int, rel_idx: int,
                batch_size: int, target_value) -> TrainingBatch:
    """Slice the next ``batch_size`` edges of one relation, advance its offset, and wrap them.

    ``offsets`` is mutated in place. Every edge in the returned batch gets
    the float32 target ``target_value``.
    """
    et = edge_types[edge_idx]
    ofs = offsets[edge_idx][rel_idx]
    lst = edge_lists[edge_idx][rel_idx][ofs:ofs + batch_size]
    offsets[edge_idx][rel_idx] = ofs + batch_size
    targets = torch.full((len(lst),), target_value, dtype=torch.float32)
    return TrainingBatch(et.vertex_type_row, et.vertex_type_column,
                         rel_idx, lst, targets)


class DualBatcher(object):
    """Iterate over paired positive / negative edge batches.

    Each step yields ``(pos_batch, neg_batch)``: two ``TrainingBatch``
    objects for the same edge type (matched by ``edge_types`` key) and the
    same relation index. Each holds at most ``batch_size`` edges, with
    target 1 for the positive batch and 0 for the negative one. Iteration
    stops when every positive edge list is exhausted.
    """

    def __init__(self, pos_data: Data, neg_data: Data,
                 batch_size: int = 512, shuffle: bool = True) -> None:
        """
        Raises:
            TypeError: if either argument is not a ``Data`` instance.
            ValueError: if the two datasets are organized differently (see
                ``_same_data_org``), or if ``batch_size`` is not positive.
        """
        if not isinstance(pos_data, Data):
            raise TypeError('pos_data must be an instance of Data')
        if not isinstance(neg_data, Data):
            raise TypeError('neg_data must be an instance of Data')
        if not _same_data_org(pos_data, neg_data):
            raise ValueError('pos_data and neg_data must have the same organization')
        batch_size = int(batch_size)
        # A non-positive batch size would never advance the offsets and
        # make iteration loop forever.
        if batch_size <= 0:
            raise ValueError('batch_size must be a positive integer')

        self.pos_data = pos_data
        self.neg_data = neg_data
        self.batch_size = batch_size
        self.shuffle = bool(shuffle)

    def get_edge_lists(self, data: Data, keys: Optional[Sequence] = None):
        """Prepare the edge lists of ``data`` for batching.

        Args:
            data: the dataset to read ``edge_types`` from.
            keys: optional order in which to lay out the edge types. It
                defaults to the dict's own iteration order. Passing the
                positive data's keys aligns the negative layout with the
                positive one.

        Returns:
            ``(edge_keys, edge_types, edge_lists, offsets)``, with edge
            types in the same order in all four.
        """
        if keys is None:
            edge_keys = list(data.edge_types.keys())
        else:
            edge_keys = list(keys)
        edge_types = [data.edge_types[k] for k in edge_keys]
        edge_lists, offsets = _build_edge_lists(edge_types, self.shuffle)
        return (edge_keys, edge_types, edge_lists, offsets)

    def get_candidates(self, edge_lists, offsets):
        """Randomly pick a non-exhausted ``(edge_idx, rel_idx)``, or ``(None, None)`` when all are consumed."""
        return _pick_relation(edge_lists, offsets)

    def take_edges(self, edge_idx, rel_idx, edge_lists, offsets,
                   edge_types, target_value):
        """Take the next batch from ``edge_lists[edge_idx][rel_idx]``, advancing ``offsets`` in place."""
        return _take_batch(edge_types, edge_lists, offsets, edge_idx, rel_idx,
                           self.batch_size, target_value)

    def __iter__(self) -> Iterator[Tuple[TrainingBatch, TrainingBatch]]:
        pos_keys, pos_types, pos_lists, pos_offsets = \
            self.get_edge_lists(self.pos_data)
        # The two edge_types dicts are only guaranteed to share the same key
        # set, not the same insertion order. Lay out the negative data in the
        # positive key order so that equal indices mean the same edge type.
        _, neg_types, neg_lists, neg_offsets = \
            self.get_edge_lists(self.neg_data, keys=pos_keys)

        while True:
            edge_idx, rel_idx = self.get_candidates(pos_lists, pos_offsets)
            if edge_idx is None:
                return
            pos_batch = self.take_edges(edge_idx, rel_idx, pos_lists,
                                        pos_offsets, pos_types, 1)
            neg_batch = self.take_edges(edge_idx, rel_idx, neg_lists,
                                        neg_offsets, neg_types, 0)
            yield (pos_batch, neg_batch)


class Batcher(object):
    """Iterate over ``TrainingBatch`` objects drawn from a single ``Data``.

    Each batch holds at most ``batch_size`` edges of one randomly chosen,
    non-exhausted (edge type, relation) pair. All edges carry the target
    ``data.target_value``.
    """

    def __init__(self, data: Data, batch_size: int = 512,
                 shuffle: bool = True) -> None:
        """
        Raises:
            TypeError: if ``data`` is not a ``Data`` instance.
            ValueError: if ``batch_size`` is not positive.
        """
        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')
        batch_size = int(batch_size)
        # A non-positive batch size would never advance the offsets and
        # make iteration loop forever.
        if batch_size <= 0:
            raise ValueError('batch_size must be a positive integer')

        self.data = data
        self.batch_size = batch_size
        self.shuffle = bool(shuffle)

    def __iter__(self) -> Iterator[TrainingBatch]:
        edge_types = list(self.data.edge_types.values())
        edge_lists, offsets = _build_edge_lists(edge_types, self.shuffle)

        while True:
            edge_idx, rel_idx = _pick_relation(edge_lists, offsets)
            if edge_idx is None:
                return
            yield _take_batch(edge_types, edge_lists, offsets, edge_idx,
                              rel_idx, self.batch_size, self.data.target_value)