from dataclasses import dataclass
from typing import Iterator, Optional, Tuple

import torch

from .declayer import Predictions
from .trainprep import PreparedData


# Data partitions accepted by flatten_predictions and PredictionsBatch.
_PART_TYPES = ('train', 'val', 'test')

# Edge-set attribute names paired with the target label given to their
# predictions. flatten_predictions concatenates predictions in this order and
# PredictionsBatch counts edges in this order; sharing one definition keeps
# the two in agreement.
_EDGE_TYPES = (
    ('edges_pos', 1),
    ('edges_neg', 0),
    ('edges_back_pos', 1),
    ('edges_back_neg', 0),
)


@dataclass
class FlatPredictions(object):
    """Predictions for one partition concatenated into flat tensors.

    Instances are produced by ``flatten_predictions``.
    """
    # Concatenation (torch.cat) of every per-edge-set prediction tensor.
    predictions: torch.Tensor
    # Same shape as ``predictions``; 1 for predictions taken from
    # edges_pos / edges_back_pos and 0 for edges_neg / edges_back_neg.
    truth: torch.Tensor
    # Partition the predictions were taken from: 'train', 'val' or 'test'.
    part_type: str


def flatten_predictions(pred: Predictions,
                        part_type: str = 'train') -> FlatPredictions:
    """Concatenate all predictions of one partition and build 0/1 targets.

    Walks ``pred.relation_families`` -> ``fam.relation_types`` and, for each
    relation, the edge sets ``edges_pos``, ``edges_neg``, ``edges_back_pos``
    and ``edges_back_neg`` in that order, reading the ``part_type``
    attribute of each edge set.

    Args:
        pred: Predictions to flatten.
        part_type: One of 'train', 'val' or 'test'.

    Returns:
        FlatPredictions whose ``truth`` holds 1 for predictions from the
        positive edge sets and 0 for those from the negative edge sets.

    Raises:
        TypeError: If ``pred`` is not a ``Predictions`` instance.
        ValueError: If ``part_type`` is not 'train', 'val' or 'test'.
    """
    if not isinstance(pred, Predictions):
        raise TypeError('pred must be an instance of Predictions')
    if part_type not in _PART_TYPES:
        raise ValueError('part_type must be set to train, val or test')

    inputs = []
    targets = []
    for fam in pred.relation_families:
        for rel in fam.relation_types:
            for et, tgt in _EDGE_TYPES:
                edge_pred = getattr(getattr(rel, et), part_type)
                inputs.append(edge_pred)
                targets.append(torch.ones_like(edge_pred) * tgt)

    # NOTE: torch.cat rejects an empty list, so this raises when there are
    # no relation types at all.
    return FlatPredictions(torch.cat(inputs), torch.cat(targets), part_type)


@dataclass
class BatchIndices(object):
    """Indices selecting one mini-batch out of a ``FlatPredictions``."""
    # Positions into FlatPredictions.predictions / .truth.
    indices: torch.Tensor
    # Partition these indices refer to; must match the FlatPredictions one.
    part_type: str


def gather_batch_indices(pred: FlatPredictions,
                         indices: BatchIndices
                         ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the predictions and targets addressed by ``indices``.

    Args:
        pred: Flattened predictions of one partition.
        indices: Batch indices for the same partition.

    Returns:
        ``(pred.predictions[indices.indices], pred.truth[indices.indices])``.

    Raises:
        TypeError: If an argument has the wrong type.
        ValueError: If ``pred`` and ``indices`` refer to different partitions.
    """
    if not isinstance(pred, FlatPredictions):
        raise TypeError('pred must be an instance of FlatPredictions')
    if not isinstance(indices, BatchIndices):
        raise TypeError('indices must be an instance of BatchIndices')
    if pred.part_type != indices.part_type:
        raise ValueError('part_type must be the same in pred and indices')
    return pred.predictions[indices.indices], pred.truth[indices.indices]


class PredictionsBatch(object):
    """Iterable of ``BatchIndices`` covering every edge of one partition.

    The total number of edges is counted from ``prep_d`` over the same edge
    sets, in the same order, that ``flatten_predictions`` concatenates.
    Iterating yields consecutive index slices of at most ``batch_size``
    elements, optionally over a random permutation.

    Attributes:
        prep_d: The prepared data the edges were counted from.
        part_type: 'train', 'val' or 'test'.
        batch_size: Maximum number of indices per batch (>= 1).
        shuffle: Whether each iteration permutes the indices.
        generator: Random generator used for shuffling.
        total_edge_count: Number of indices yielded per full iteration.
    """

    def __init__(self, prep_d: PreparedData, part_type: str = 'train',
                 batch_size: int = 100, shuffle: bool = False,
                 generator: Optional[torch.Generator] = None) -> None:
        """Validate arguments and count the partition's edges.

        Raises:
            TypeError: If ``prep_d`` is not a ``PreparedData`` or
                ``generator`` is neither None nor a ``torch.Generator``.
            ValueError: If ``part_type`` is invalid or ``batch_size`` < 1.
        """
        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')
        if part_type not in _PART_TYPES:
            raise ValueError('part_type must be set to train, val or test')
        batch_size = int(batch_size)
        # A zero step would make range() fail only once iteration starts and
        # a negative one would silently yield nothing; reject both up front.
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')
        shuffle = bool(shuffle)
        if generator is not None and \
                not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

        self.prep_d = prep_d
        self.part_type = part_type
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.generator = (generator if generator is not None
                          else torch.default_generator)

        # NOTE(review): assumes len() of each edge set in prep_d equals the
        # number of predictions flatten_predictions emits for it, so the
        # yielded indices line up with FlatPredictions — TODO confirm.
        self.total_edge_count = sum(
            len(getattr(getattr(rel, et), part_type))
            for fam in prep_d.relation_families
            for rel in fam.relation_types
            for et, _ in _EDGE_TYPES)

    def __iter__(self) -> Iterator[BatchIndices]:
        """Yield ``BatchIndices`` of at most ``batch_size`` indices each.

        Every index in ``0 .. total_edge_count - 1`` appears exactly once per
        iteration, in ascending order or, when ``shuffle`` is set, in a
        random permutation drawn from ``self.generator``.
        """
        n = self.total_edge_count
        if self.shuffle:
            # Draw from self.generator so the constructor's ``generator``
            # argument actually controls (and can seed) the shuffle.
            # NOTE(review): randperm runs on CPU here, so a CUDA generator
            # would be rejected by torch.
            values = torch.randperm(n, generator=self.generator)
        else:
            values = torch.arange(n)
        for start in range(0, n, self.batch_size):
            yield BatchIndices(values[start:start + self.batch_size],
                               self.part_type)