from icosagon.declayer import Predictions
import torch


class PredictionsBatch:
    """Iterate over the predictions of one data split in fixed-size batches.

    For the chosen split, every relation type's positive/negative (and
    "back" positive/negative) edge predictions are concatenated along
    dim 0 into a single tensor. A parallel target tensor is built with
    label 1 for positive edges and 0 for negative edges. Iteration then
    yields consecutive ``(input, target)`` slices of ``batch_size`` rows.
    The last batch may be smaller.
    """

    # (attribute name on each relation type, target label for its edges)
    _EDGE_TYPES = (
        ('edges_pos', 1),
        ('edges_neg', 0),
        ('edges_back_pos', 1),
        ('edges_back_neg', 0),
    )

    def __init__(self, pred: Predictions, part_type: str = 'train',
                 batch_size: int = 100) -> None:
        """Store the predictions to batch over.

        Args:
            pred: Predictions object whose ``relation_families`` are walked
                on iteration.
            part_type: Which split to read from each edge-prediction object;
                one of ``'train'``, ``'val'`` or ``'test'``.
            batch_size: Number of rows (along dim 0) per yielded batch.
                Coerced with ``int()``.

        Raises:
            TypeError: If ``pred`` is not a ``Predictions`` instance.
            ValueError: If ``part_type`` is not a known split, or
                ``batch_size`` is not a positive integer.
        """
        if not isinstance(pred, Predictions):
            raise TypeError('pred must be an instance of Predictions')
        if part_type not in ['train', 'val', 'test']:
            raise ValueError('part_type must be set to train, val or test')
        batch_size = int(batch_size)
        # Fail early: 0 would crash inside range() only at iteration time,
        # and a negative step would silently produce no batches at all.
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')
        self.predictions = pred
        self.part_type = part_type
        self.batch_size = batch_size

    def _gather(self):
        """Return parallel lists of prediction tensors and target tensors."""
        inputs = []
        targets = []
        for fam in self.predictions.relation_families:
            for rel in fam.relation_types:
                for attr_name, label in self._EDGE_TYPES:
                    edge_pred = getattr(getattr(rel, attr_name),
                                        self.part_type)
                    inputs.append(edge_pred)
                    # Target has the same shape/dtype as the prediction,
                    # filled with the edge type's label (1 or 0).
                    targets.append(torch.ones_like(edge_pred) * label)
        return inputs, targets

    def __iter__(self):
        """Yield ``(input, target)`` tensor pairs of at most ``batch_size`` rows."""
        inputs, targets = self._gather()
        # torch.cat raises on an empty list; with no relation types there
        # is simply nothing to iterate over.
        if not inputs:
            return
        all_inputs = torch.cat(inputs)
        all_targets = torch.cat(targets)
        for start in range(0, len(all_inputs), self.batch_size):
            stop = start + self.batch_size
            yield (all_inputs[start:stop], all_targets[start:stop])