|
- from icosagon.declayer import Predictions
- import torch
-
-
class PredictionsBatch(object):
    """Iterate over ``(predictions, targets)`` mini-batches for one data split.

    Collects the edge predictions stored in a ``Predictions`` object for the
    selected split, concatenates them into a single tensor together with a
    matching tensor of 0/1 targets, and yields consecutive slices of at most
    ``batch_size`` elements.

    Args:
        pred: The ``Predictions`` instance to read edge predictions from.
        part_type: Which split to read: ``'train'``, ``'val'`` or ``'test'``.
        batch_size: Maximum number of elements per yielded batch. Coerced with
            ``int()``; must be at least 1.

    Raises:
        TypeError: If ``pred`` is not a ``Predictions`` instance.
        ValueError: If ``part_type`` is not one of the allowed splits, or if
            ``batch_size`` is less than 1.
    """

    # (attribute name on each relation type, target label for its predictions)
    _EDGE_TYPES = (
        ('edges_pos', 1),
        ('edges_neg', 0),
        ('edges_back_pos', 1),
        ('edges_back_neg', 0),
    )

    def __init__(self, pred: Predictions, part_type: str = 'train',
                 batch_size: int = 100) -> None:
        if not isinstance(pred, Predictions):
            raise TypeError('pred must be an instance of Predictions')

        if part_type not in ['train', 'val', 'test']:
            raise ValueError('part_type must be set to train, val or test')

        batch_size = int(batch_size)
        # A zero step would only fail once iteration starts (inside range()),
        # and a negative one would silently yield no batches; reject both here.
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')

        self.predictions = pred
        self.part_type = part_type
        self.batch_size = batch_size

    def __iter__(self):
        """Yield ``(predictions, targets)`` pairs of at most ``batch_size`` elements.

        Predictions are concatenated along dim 0 in relation-family, then
        relation-type, then edge-type order (``edges_pos``, ``edges_neg``,
        ``edges_back_pos``, ``edges_back_neg``). Each target tensor has the same
        shape and dtype as its prediction tensor and is filled with 1 for the
        positive edge types and 0 for the negative ones. Yields nothing when
        there are no relation types at all.
        """
        preds = []
        targets = []
        for fam in self.predictions.relation_families:
            for rel in fam.relation_types:
                for attr_name, label in self._EDGE_TYPES:
                    # e.g. rel.edges_pos.train
                    edge_pred = getattr(getattr(rel, attr_name), self.part_type)
                    preds.append(edge_pred)
                    targets.append(torch.full_like(edge_pred, label))

        # torch.cat() rejects an empty list; with nothing collected there are
        # simply no batches to produce.
        if not preds:
            return

        all_preds = torch.cat(preds)
        all_targets = torch.cat(targets)

        # NOTE(review): presumably each edge prediction is 1-D, so len() is the
        # number of edges; batching is along dim 0 either way.
        for start in range(0, len(all_preds), self.batch_size):
            stop = start + self.batch_size
            yield all_preds[start:stop], all_targets[start:stop]
|