import torch
from typing import List
from .trainprep import PreparedData
from dataclasses import dataclass
import random
from collections import defaultdict


@dataclass
class TrainingBatch(object):
    """A batch of edges drawn from a single relation type.

    Produced by ``FastBatcher.__iter_old__``; consumed by ``FastDecLayer``.
    For back-direction edge types the producer swaps ``node_type_row`` and
    ``node_type_column``.
    """
    relation_family_index: int
    relation_type_index: int
    node_type_row: int
    node_type_column: int
    # Rows selected from ``rel.<edge type>.train``; presumably an (n, 2)
    # index tensor — TODO confirm against trainprep.
    edges: torch.Tensor


class FastBatcher(object):
    """Serve the training edges of every relation family in shuffled batches.

    ``build`` creates two edge pools per relation family: a "forward" pool
    built from ``edges_pos``/``edges_neg`` and a "back" pool built from
    ``edges_back_pos``/``edges_back_neg``. Each pool concatenates the edges
    of all the family's relation types, labelled 1 (positive) / 0
    (negative), and is shuffled once. ``__iter__`` then yields batches of at
    most ``batch_size`` edges from randomly chosen, not yet exhausted pools.

    Args:
        prep_d: Prepared data whose ``relation_families`` supply the edges.
        batch_size: Maximum number of edges per yielded batch; must be > 0.

    Raises:
        TypeError: If ``prep_d`` is not a ``PreparedData`` instance.
        ValueError: If ``batch_size`` is not positive (the iterators would
            otherwise never advance and loop forever).
    """

    def __init__(self, prep_d: PreparedData, batch_size: int) -> None:
        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        self.prep_d = prep_d
        self.batch_size = int(batch_size)
        if self.batch_size <= 0:
            raise ValueError('batch_size must be a positive integer')

        # List of edge-pool dicts, filled by build().
        self.edges = None

        self.build()

    @staticmethod
    def _build_pool(fam_idx: int, fam, back: bool) -> dict:
        """Concatenate and shuffle one direction's training edges of a family."""
        if back:
            pos_name, neg_name = 'edges_back_pos', 'edges_back_neg'
        else:
            pos_name, neg_name = 'edges_pos', 'edges_neg'

        edges = []
        targets = []
        rel_idx = []
        for idx, rel in enumerate(fam.relation_types):
            pos = getattr(rel, pos_name).train
            neg = getattr(rel, neg_name).train
            edges.append(pos)
            edges.append(neg)
            targets.append(torch.ones(len(pos)))
            targets.append(torch.zeros(len(neg)))
            # The pool mixes relation types, so record the type per edge.
            rel_idx.append(torch.full((len(pos) + len(neg),), idx,
                dtype=torch.long))

        edges = torch.cat(edges)
        targets = torch.cat(targets)
        rel_idx = torch.cat(rel_idx)

        # One permutation keeps edges, targets and rel_idx aligned.
        order = torch.randperm(len(edges))
        return {'fam_idx': fam_idx, 'rel_idx': rel_idx[order], 'back': back,
            'edges': edges[order], 'targets': targets[order], 'ofs': 0}

    def build(self) -> None:
        """(Re)build the shuffled edge pools in ``self.edges``.

        Each relation family contributes a ``'back': False`` pool followed by
        a ``'back': True`` pool. Every pool is a dict with:

        * ``'fam_idx'``: index of the relation family.
        * ``'rel_idx'``: LongTensor with, for each edge, the index of the
          relation type (within the family) the edge came from.
        * ``'back'``: which direction the pool holds.
        * ``'edges'``: the concatenated training edges, shuffled.
        * ``'targets'``: float labels aligned with ``'edges'`` — 1 for
          positive edges, 0 for negative ones.
        * ``'ofs'``: read offset advanced by ``__iter__``.

        Families with no relation types are skipped, since there is nothing
        to batch and ``torch.cat`` rejects an empty list.
        """
        self.edges = []

        for fam_idx, fam in enumerate(self.prep_d.relation_families):
            if len(fam.relation_types) == 0:
                continue
            self.edges.append(self._build_pool(fam_idx, fam, back=False))
            self.edges.append(self._build_pool(fam_idx, fam, back=True))

    def __iter__(self):
        """Yield batches until every pool is exhausted (one epoch).

        Each step picks, uniformly at random, a pool that still has unread
        edges and yields its next ``batch_size`` edges as a dict with keys
        ``'fam_idx'``, ``'back'``, ``'rel_idx'``, ``'edges'`` and
        ``'targets'`` (tensor slices aligned with each other).

        Offsets are reset when iteration starts, so iterating again replays
        the order shuffled by ``build``; call ``build()`` to reshuffle.
        Offsets are stored on the pools themselves, so two interleaved
        iterators over the same batcher interfere with each other.
        """
        for pool in self.edges:
            pool['ofs'] = 0

        while True:
            pending = [pool for pool in self.edges
                if pool['ofs'] < len(pool['edges'])]
            if not pending:
                return

            pool = pending[torch.randint(0, len(pending), (1,)).item()]
            start = pool['ofs']
            end = start + self.batch_size
            pool['ofs'] = end

            yield {
                'fam_idx': pool['fam_idx'],
                'back': pool['back'],
                'rel_idx': pool['rel_idx'][start:end],
                'edges': pool['edges'][start:end],
                'targets': pool['targets'][start:end],
            }

    def __iter_old__(self):
        """Legacy per-edge-type batcher yielding ``TrainingBatch`` objects.

        The candidates are every (family, relation type, edge type) triple,
        where the edge type is one of ``edges_pos``, ``edges_neg``,
        ``edges_back_pos`` and ``edges_back_neg``. Each step picks a triple
        uniformly at random among those with unread edges. On first use, that
        triple's training edges are shuffled. The next ``batch_size`` of them
        are then yielded. For ``edges_back_*`` the row/column node types are
        swapped. Iteration stops once every triple is exhausted.

        ``TrainingBatch`` has no target field, so whether a batch is
        positive or negative is not carried along.
        """
        edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']

        offsets = {}
        orders = {}
        pending = [(fam_idx, rel_idx, et)
            for fam_idx, fam in enumerate(self.prep_d.relation_families)
            for rel_idx in range(len(fam.relation_types))
            for et in edge_types]

        while pending:
            pick = torch.randint(0, len(pending), (1,)).item()
            key = pending[pick]
            fam_idx, rel_idx, et = key
            rel = self.prep_d.relation_families[fam_idx].relation_types[rel_idx]
            edges = getattr(rel, et).train

            if key not in orders:
                orders[key] = torch.randperm(len(edges))
                offsets[key] = 0

            ofs = offsets[key]
            if ofs >= len(edges):
                # Exhausted: stop sampling this triple.
                pending.pop(pick)
                continue

            nt_row = rel.node_type_row
            nt_col = rel.node_type_column
            if 'back' in et:
                nt_row, nt_col = nt_col, nt_row

            offsets[key] = ofs + self.batch_size
            batch_order = orders[key][ofs:ofs + self.batch_size]
            yield TrainingBatch(fam_idx, rel_idx, nt_row, nt_col,
                edges[batch_order])


class FastDecLayer(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, last_layer_repr: List[torch.Tensor], training_batch: TrainingBatch):
        """Presumably decodes ``training_batch`` from ``last_layer_repr``.

        NOTE(review): ``last_layer_repr`` likely holds one tensor per node
        type, indexed by ``training_batch.node_type_row`` /
        ``node_type_column`` — confirm against the implementation.
        """