diff --git a/src/icosagon/batch.py b/src/icosagon/batch.py index 9a28712..e395f3f 100644 --- a/src/icosagon/batch.py +++ b/src/icosagon/batch.py @@ -1,13 +1,72 @@ -from icosagon.declayer import Predictions +from .declayer import Predictions import torch +from dataclasses import dataclass +from .trainprep import PreparedData +from typing import Tuple + + +@dataclass +class FlatPredictions(object): + predictions: torch.Tensor + truth: torch.Tensor + part_type: str + + +def flatten_predictions(pred: Predictions, part_type: str = 'train'): + if not isinstance(pred, Predictions): + raise TypeError('pred must be an instance of Predictions') + + if part_type not in ['train', 'val', 'test']: + raise ValueError('part_type must be set to train, val or test') + + edge_types = [('edges_pos', 1), ('edges_neg', 0), + ('edges_back_pos', 1), ('edges_back_neg', 0)] + + input = [] + target = [] + + for fam in pred.relation_families: + for rel in fam.relation_types: + for (et, tgt) in edge_types: + edge_pred = getattr(getattr(rel, et), part_type) + input.append(edge_pred) + target.append(torch.ones_like(edge_pred) * tgt) + + input = torch.cat(input) + target = torch.cat(target) + + return FlatPredictions(input, target, part_type) + + +@dataclass +class BatchIndices(object): + indices: torch.Tensor + part_type: str + + +def gather_batch_indices(pred: FlatPredictions, + indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]: + + if not isinstance(pred, FlatPredictions): + raise TypeError('pred must be an instance of FlatPredictions') + + if not isinstance(indices, BatchIndices): + raise TypeError('indices must be an instance of BatchIndices') + + if pred.part_type != indices.part_type: + raise ValueError('part_type must be the same in pred and indices') + + return (pred.predictions[indices.indices], + pred.truth[indices.indices]) class PredictionsBatch(object): - def __init__(self, pred: Predictions, part_type: str = 'train', - batch_size: int = 100, shuffle: bool = False) -> None: + def __init__(self, prep_d: PreparedData, part_type: str = 'train', + batch_size: int = 100, shuffle: bool = False, + generator: torch.Generator = None) -> None: - if not isinstance(pred, Predictions): - raise TypeError('pred must be an instance of Predictions') + if not isinstance(prep_d, PreparedData): + raise TypeError('prep_d must be an instance of PreparedData') if part_type not in ['train', 'val', 'test']: raise ValueError('part_type must be set to train, val or test') @@ -16,32 +75,28 @@ class PredictionsBatch(object): shuffle = bool(shuffle) - self.predictions = pred + if generator is not None and not isinstance(generator, torch.Generator): + raise TypeError('generator must be an instance of torch.Generator') + + self.prep_d = prep_d self.part_type = part_type self.batch_size = batch_size self.shuffle = shuffle + self.generator = generator or torch.default_generator - def __iter__(self): - edge_types = [('edges_pos', 1), ('edges_neg', 0), - ('edges_back_pos', 1), ('edges_back_neg', 0)] - - input = [] - target = [] - - for fam in self.predictions.relation_families: + count = 0 + for fam in prep_d.relation_families: for rel in fam.relation_types: - for (et, tgt) in edge_types: - edge_pred = getattr(getattr(rel, et), self.part_type) - input.append(edge_pred) - target.append(torch.ones_like(edge_pred) * tgt) - - input = torch.cat(input) - target = torch.cat(target) + for et in ['edges_pos', 'edges_neg', + 'edges_back_pos', 'edges_back_neg']: + count += len(getattr(getattr(rel, et), part_type)) + self.total_edge_count = count + def __iter__(self): + values = torch.arange(self.total_edge_count) if self.shuffle: - perm = torch.randperm(len(input)) - input = input[perm] - target = target[perm] + perm = torch.randperm(len(values)) + values = values[perm] - for i in range(0, len(input), self.batch_size): - yield (input[i:i+self.batch_size], target[i:i+self.batch_size]) + for i in range(0, len(values), self.batch_size): + yield BatchIndices(values[i:i+self.batch_size], self.part_type) diff --git a/src/icosagon/trainloop.py b/src/icosagon/trainloop.py index 2ce8240..5492878 100644 --- a/src/icosagon/trainloop.py +++ b/src/icosagon/trainloop.py @@ -1,6 +1,8 @@ from .model import Model import torch -from .batch import PredictionsBatch +from .batch import PredictionsBatch, \ + flatten_predictions, \ + gather_batch_indices from typing import Callable from types import FunctionType @@ -9,7 +11,7 @@ class TrainLoop(object): def __init__(self, model: Model, lr: float = 0.001, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ torch.nn.functional.binary_cross_entropy_with_logits, - batch_size: int = 100) -> None: + batch_size: int = 100, generator: torch.Generator = None) -> None: if not isinstance(model, Model): raise TypeError('model must be an instance of Model') @@ -21,10 +23,14 @@ class TrainLoop(object): batch_size = int(batch_size) + if generator is not None and not isinstance(generator, torch.Generator): + raise TypeError('generator must be an instance of torch.Generator') + self.model = model self.lr = lr self.loss = loss self.batch_size = batch_size + self.generator = generator or torch.default_generator self.opt = None @@ -35,22 +41,25 @@ class TrainLoop(object): self.opt = opt def run_epoch(self): - pred = self.model(None) - batch = PredictionsBatch(pred, batch_size=self.batch_size) - n = len(list(iter(batch))) + batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size, + generator=self.generator) + # pred = self.model(None) + # n = len(list(iter(batch))) loss_sum = 0 - for i in range(n): + for indices in batch: self.opt.zero_grad() pred = self.model(None) - batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) - seed = torch.rand(1).item() - rng_state = torch.get_rng_state() - torch.manual_seed(seed) - it = iter(batch) - torch.set_rng_state(rng_state) - for k in range(i): - _ = next(it) - (input, target) = next(it) + pred = flatten_predictions(pred) + # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) + # seed = torch.rand(1).item() + # rng_state = torch.get_rng_state() + # torch.manual_seed(seed) + #it = iter(batch) + #torch.set_rng_state(rng_state) + #for k in range(i): + #_ = next(it) + #(input, target) = next(it) + (input, target) = gather_batch_indices(pred, indices) loss = self.loss(input, target) loss.backward() self.opt.step() diff --git a/tests/icosagon/test_batch.py b/tests/icosagon/test_batch.py index 3d185e4..b6cd6d6 100644 --- a/tests/icosagon/test_batch.py +++ b/tests/icosagon/test_batch.py @@ -1,4 +1,8 @@ -from icosagon.batch import PredictionsBatch +from icosagon.batch import PredictionsBatch, \ + FlatPredictions, \ + flatten_predictions, \ + BatchIndices, \ + gather_batch_indices from icosagon.declayer import Predictions, \ RelationPredictions, \ RelationFamilyPredictions @@ -6,6 +10,113 @@ from icosagon.trainprep import prepare_training, \ TrainValTest from icosagon.data import Data import torch +import pytest + + +def test_flat_predictions_01(): + pred = FlatPredictions(torch.tensor([0, 1, 0, 1]), + torch.tensor([1, 0, 1, 0]), 'train') + + assert torch.all(pred.predictions == torch.tensor([0, 1, 0, 1])) + assert torch.all(pred.truth == torch.tensor([1, 0, 1, 0])) + assert pred.part_type == 'train' + + +def test_flatten_predictions_01(): + rel_pred = RelationPredictions( + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) + ) + fam_pred = RelationFamilyPredictions([ rel_pred ]) + pred = Predictions([ fam_pred ]) + + pred_flat = flatten_predictions(pred, part_type='train') + + assert torch.all(pred_flat.predictions == \ + torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1], dtype=torch.float32)) + assert torch.all(pred_flat.truth == \ + torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=torch.float32)) + assert pred_flat.part_type == 'train' + + +def test_flatten_predictions_02(): + rel_pred = RelationPredictions( + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) + ) + fam_pred = RelationFamilyPredictions([ rel_pred ]) + pred = Predictions([ fam_pred ]) + + pred_flat = flatten_predictions(pred, part_type='val') + + assert len(pred_flat.predictions) == 0 + assert len(pred_flat.truth) == 0 + assert pred_flat.part_type == 'val' + + +def test_flatten_predictions_03(): + rel_pred = RelationPredictions( + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) + ) + fam_pred = RelationFamilyPredictions([ rel_pred ]) + pred = Predictions([ fam_pred ]) + + pred_flat = flatten_predictions(pred, part_type='test') + + assert len(pred_flat.predictions) == 0 + assert len(pred_flat.truth) == 0 + assert pred_flat.part_type == 'test' + + +def test_flatten_predictions_04(): + rel_pred = RelationPredictions( + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) + ) + fam_pred = RelationFamilyPredictions([ rel_pred ]) + pred = Predictions([ fam_pred ]) + + with pytest.raises(TypeError): + pred_flat = flatten_predictions(1, part_type='test') + + with pytest.raises(ValueError): + pred_flat = flatten_predictions(pred, part_type='x') + + +def test_batch_indices_01(): + indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train') + assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4])) + assert indices.part_type == 'train' + + +def test_gather_batch_indices_01(): + rel_pred = RelationPredictions( + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) + ) + fam_pred = RelationFamilyPredictions([ rel_pred ]) + pred = Predictions([ fam_pred ]) + + pred_flat = flatten_predictions(pred, part_type='train') + + indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train') + + (input, target) = gather_batch_indices(pred_flat, indices) + assert torch.all(input == \ + torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32)) + assert torch.all(target == \ + torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32)) def test_predictions_batch_01(): @@ -38,10 +149,13 @@ def test_predictions_batch_01(): fam_pred = RelationFamilyPredictions([ rel_pred ]) pred = Predictions([ fam_pred ]) - batch = PredictionsBatch(pred, part_type='train', batch_size=1) + pred_flat = flatten_predictions(pred, part_type='train') + + batch = PredictionsBatch(prep_d, part_type='train', batch_size=1) count = 0 lst = [] - for (input, target) in batch: + for indices in batch: + (input, target) = gather_batch_indices(pred_flat, indices) assert len(input) == 1 assert len(target) == 1 lst.append((input[0], target[0]))