|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- from .declayer import Predictions
- import torch
- from dataclasses import dataclass
- from .trainprep import PreparedData
- from typing import Tuple
-
-
@dataclass
class FlatPredictions:
    """Predictions and their targets for one data split, flattened to 1-D.

    This is the container returned by ``flatten_predictions``; batches of it
    are selected with ``gather_batch_indices``.
    """

    # Concatenated prediction values of every edge in the split.
    predictions: torch.Tensor
    # Targets aligned element-wise with ``predictions`` (as built by
    # flatten_predictions: 1 for positive edges, 0 for negative ones).
    truth: torch.Tensor
    # Name of the data split the tensors were taken from.
    part_type: str
-
-
def flatten_predictions(pred: Predictions, part_type: str = 'train'):
    """Concatenate every edge prediction of one data split into flat tensors.

    For each relation type of each relation family in ``pred`` (in iteration
    order), the ``part_type`` predictions of the positive, negative,
    back-positive and back-negative edges are appended one after another.
    A target tensor of the same layout is built alongside, holding 1 for
    positive edges and 0 for negative ones.

    Args:
        pred: Model predictions to flatten.
        part_type: Data split to take: 'train', 'val' or 'test'.

    Returns:
        FlatPredictions holding the concatenated predictions, the matching
        targets and ``part_type``.

    Raises:
        TypeError: If ``pred`` is not a ``Predictions`` instance.
        ValueError: If ``part_type`` is not 'train', 'val' or 'test'.
        RuntimeError: From ``torch.cat`` when ``pred`` contains no relation
            types, i.e. there is nothing to concatenate.
    """
    if not isinstance(pred, Predictions):
        raise TypeError('pred must be an instance of Predictions')

    if part_type not in ['train', 'val', 'test']:
        raise ValueError('part_type must be set to train, val or test')

    # (attribute name, target label). This order must stay in sync with the
    # edge counting in PredictionsBatch: its indices address positions in the
    # tensors concatenated here.
    edge_types = [('edges_pos', 1), ('edges_neg', 0),
                  ('edges_back_pos', 1), ('edges_back_neg', 0)]

    predictions = []
    targets = []

    for fam in pred.relation_families:
        for rel in fam.relation_types:
            for et, tgt in edge_types:
                edge_pred = getattr(getattr(rel, et), part_type)
                predictions.append(edge_pred)
                # Same shape/dtype/device as the predictions, filled with the
                # label, without the extra multiply of ones_like(...) * tgt.
                targets.append(torch.full_like(edge_pred, tgt))

    return FlatPredictions(torch.cat(predictions), torch.cat(targets),
                           part_type)
-
-
@dataclass
class BatchIndices:
    """Positions of one mini-batch inside flattened predictions.

    ``PredictionsBatch`` yields these, and ``gather_batch_indices`` uses them
    to select entries from a ``FlatPredictions`` of the same split.
    """

    # Tensor of positions used to index the flat prediction/target tensors.
    indices: torch.Tensor
    # Data split the positions refer to; gather_batch_indices requires it to
    # equal the FlatPredictions' part_type.
    part_type: str
-
-
def gather_batch_indices(pred: FlatPredictions,
                         indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the predictions and targets at one batch's positions.

    Args:
        pred: Flattened predictions and targets of a single data split.
        indices: Batch positions into ``pred``; must come from the same split.

    Returns:
        A ``(predictions, truth)`` pair, each being the corresponding tensor
        of ``pred`` indexed by ``indices.indices``.

    Raises:
        TypeError: If either argument has the wrong type.
        ValueError: If the two arguments refer to different data splits.
    """
    if not isinstance(pred, FlatPredictions):
        raise TypeError('pred must be an instance of FlatPredictions')
    if not isinstance(indices, BatchIndices):
        raise TypeError('indices must be an instance of BatchIndices')
    if indices.part_type != pred.part_type:
        raise ValueError('part_type must be the same in pred and indices')

    selection = indices.indices
    picked_predictions = pred.predictions[selection]
    picked_truth = pred.truth[selection]
    return picked_predictions, picked_truth
-
-
class PredictionsBatch(object):
    """Iterable over mini-batches of positions into flattened predictions.

    The total edge count is taken from ``prep_d`` for the chosen split. It
    walks relation families, relation types and the four edge categories in
    the same order that ``flatten_predictions`` concatenates them. Assuming
    ``prep_d`` holds the same edges the predictions were made for, each
    yielded position therefore addresses one entry of the flat tensors.
    """

    def __init__(self, prep_d: PreparedData, part_type: str = 'train',
                 batch_size: int = 100, shuffle: bool = False,
                 generator: torch.Generator = None) -> None:
        """Validate the arguments and count the edges of the split.

        Args:
            prep_d: Prepared data whose ``part_type`` edges are counted.
            part_type: Data split: 'train', 'val' or 'test'.
            batch_size: Maximum number of positions per batch (coerced with
                ``int``).
            shuffle: If true, positions are visited in a fresh random
                permutation every time the object is iterated.
            generator: Random generator for the permutation. Defaults to
                ``torch.default_generator``. NOTE(review): the permutation is
                drawn with ``torch.randperm`` on its default device, so this
                presumably has to be a CPU generator — confirm.

        Raises:
            TypeError: If ``prep_d`` or ``generator`` has the wrong type.
            ValueError: If ``part_type`` is unknown or ``batch_size`` < 1.
        """
        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if part_type not in ['train', 'val', 'test']:
            raise ValueError('part_type must be set to train, val or test')

        batch_size = int(batch_size)
        # Fail fast: 0 would otherwise only blow up inside range() during
        # iteration, and a negative size would silently yield no batches.
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')

        shuffle = bool(shuffle)

        if generator is not None and not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

        self.prep_d = prep_d
        self.part_type = part_type
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.generator = (generator if generator is not None
                          else torch.default_generator)

        count = 0
        for fam in prep_d.relation_families:
            for rel in fam.relation_types:
                for et in ['edges_pos', 'edges_neg',
                           'edges_back_pos', 'edges_back_neg']:
                    count += len(getattr(getattr(rel, et), part_type))
        self.total_edge_count = count

    def __iter__(self):
        """Yield ``BatchIndices`` of at most ``batch_size`` positions each.

        Positions run over ``range(total_edge_count)``, either in order or
        permuted with ``self.generator`` when ``shuffle`` is set. The last
        batch may be smaller than ``batch_size``.
        """
        n = self.total_edge_count
        if self.shuffle:
            # Use the stored generator so a caller-supplied (seeded) generator
            # actually controls the order; it was previously ignored.
            order = torch.randperm(n, generator=self.generator)
        else:
            order = torch.arange(n)

        for start in range(0, n, self.batch_size):
            yield BatchIndices(order[start:start + self.batch_size],
                               self.part_type)
|