| @@ -1,13 +1,72 @@ | |||
| from icosagon.declayer import Predictions | |||
| from .declayer import Predictions | |||
| import torch | |||
| from dataclasses import dataclass | |||
| from .trainprep import PreparedData | |||
| from typing import Tuple | |||
| @dataclass | |||
| class FlatPredictions(object): | |||
| predictions: torch.Tensor | |||
| truth: torch.Tensor | |||
| part_type: str | |||
| def flatten_predictions(pred: Predictions, part_type: str = 'train'): | |||
| if not isinstance(pred, Predictions): | |||
| raise TypeError('pred must be an instance of Predictions') | |||
| if part_type not in ['train', 'val', 'test']: | |||
| raise ValueError('part_type must be set to train, val or test') | |||
| edge_types = [('edges_pos', 1), ('edges_neg', 0), | |||
| ('edges_back_pos', 1), ('edges_back_neg', 0)] | |||
| input = [] | |||
| target = [] | |||
| for fam in pred.relation_families: | |||
| for rel in fam.relation_types: | |||
| for (et, tgt) in edge_types: | |||
| edge_pred = getattr(getattr(rel, et), part_type) | |||
| input.append(edge_pred) | |||
| target.append(torch.ones_like(edge_pred) * tgt) | |||
| input = torch.cat(input) | |||
| target = torch.cat(target) | |||
| return FlatPredictions(input, target, part_type) | |||
| @dataclass | |||
| class BatchIndices(object): | |||
| indices: torch.Tensor | |||
| part_type: str | |||
| def gather_batch_indices(pred: FlatPredictions, | |||
| indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]: | |||
| if not isinstance(pred, FlatPredictions): | |||
| raise TypeError('pred must be an instance of FlatPredictions') | |||
| if not isinstance(indices, BatchIndices): | |||
| raise TypeError('indices must be an instance of BatchIndices') | |||
| if pred.part_type != indices.part_type: | |||
| raise ValueError('part_type must be the same in pred and indices') | |||
| return (pred.predictions[indices.indices], | |||
| pred.truth[indices.indices]) | |||
| class PredictionsBatch(object): | |||
| def __init__(self, pred: Predictions, part_type: str = 'train', | |||
| batch_size: int = 100, shuffle: bool = False) -> None: | |||
| def __init__(self, prep_d: PreparedData, part_type: str = 'train', | |||
| batch_size: int = 100, shuffle: bool = False, | |||
| generator: torch.Generator = None) -> None: | |||
| if not isinstance(pred, Predictions): | |||
| raise TypeError('pred must be an instance of Predictions') | |||
| if not isinstance(prep_d, PreparedData): | |||
| raise TypeError('prep_d must be an instance of PreparedData') | |||
| if part_type not in ['train', 'val', 'test']: | |||
| raise ValueError('part_type must be set to train, val or test') | |||
| @@ -16,32 +75,28 @@ class PredictionsBatch(object): | |||
| shuffle = bool(shuffle) | |||
| self.predictions = pred | |||
| if generator is not None and not isinstance(generator, torch.Generator): | |||
| raise TypeError('generator must be an instance of torch.Generator') | |||
| self.prep_d = prep_d | |||
| self.part_type = part_type | |||
| self.batch_size = batch_size | |||
| self.shuffle = shuffle | |||
| self.generator = generator or torch.default_generator | |||
| def __iter__(self): | |||
| edge_types = [('edges_pos', 1), ('edges_neg', 0), | |||
| ('edges_back_pos', 1), ('edges_back_neg', 0)] | |||
| input = [] | |||
| target = [] | |||
| for fam in self.predictions.relation_families: | |||
| count = 0 | |||
| for fam in prep_d.relation_families: | |||
| for rel in fam.relation_types: | |||
| for (et, tgt) in edge_types: | |||
| edge_pred = getattr(getattr(rel, et), self.part_type) | |||
| input.append(edge_pred) | |||
| target.append(torch.ones_like(edge_pred) * tgt) | |||
| input = torch.cat(input) | |||
| target = torch.cat(target) | |||
| for et in ['edges_pos', 'edges_neg', | |||
| 'edges_back_pos', 'edges_back_neg']: | |||
| count += len(getattr(getattr(rel, et), part_type)) | |||
| self.total_edge_count = count | |||
| def __iter__(self): | |||
| values = torch.arange(self.total_edge_count) | |||
| if self.shuffle: | |||
| perm = torch.randperm(len(input)) | |||
| input = input[perm] | |||
| target = target[perm] | |||
| perm = torch.randperm(len(values)) | |||
| values = values[perm] | |||
| for i in range(0, len(input), self.batch_size): | |||
| yield (input[i:i+self.batch_size], target[i:i+self.batch_size]) | |||
| for i in range(0, len(values), self.batch_size): | |||
| yield BatchIndices(values[i:i+self.batch_size], self.part_type) | |||
| @@ -1,6 +1,8 @@ | |||
| from .model import Model | |||
| import torch | |||
| from .batch import PredictionsBatch | |||
| from .batch import PredictionsBatch, \ | |||
| flatten_predictions, \ | |||
| gather_batch_indices | |||
| from typing import Callable | |||
| from types import FunctionType | |||
| @@ -9,7 +11,7 @@ class TrainLoop(object): | |||
| def __init__(self, model: Model, lr: float = 0.001, | |||
| loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | |||
| torch.nn.functional.binary_cross_entropy_with_logits, | |||
| batch_size: int = 100) -> None: | |||
| batch_size: int = 100, generator: torch.Generator = None) -> None: | |||
| if not isinstance(model, Model): | |||
| raise TypeError('model must be an instance of Model') | |||
| @@ -21,10 +23,14 @@ class TrainLoop(object): | |||
| batch_size = int(batch_size) | |||
| if generator is not None and not isinstance(generator, torch.Generator): | |||
| raise TypeError('generator must be an instance of torch.Generator') | |||
| self.model = model | |||
| self.lr = lr | |||
| self.loss = loss | |||
| self.batch_size = batch_size | |||
| self.generator = generator or torch.default_generator | |||
| self.opt = None | |||
| @@ -35,22 +41,25 @@ class TrainLoop(object): | |||
| self.opt = opt | |||
| def run_epoch(self): | |||
| pred = self.model(None) | |||
| batch = PredictionsBatch(pred, batch_size=self.batch_size) | |||
| n = len(list(iter(batch))) | |||
| batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size, | |||
| generator=self.generator) | |||
| # pred = self.model(None) | |||
| # n = len(list(iter(batch))) | |||
| loss_sum = 0 | |||
| for i in range(n): | |||
| for indices in batch: | |||
| self.opt.zero_grad() | |||
| pred = self.model(None) | |||
| batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||
| seed = torch.rand(1).item() | |||
| rng_state = torch.get_rng_state() | |||
| torch.manual_seed(seed) | |||
| it = iter(batch) | |||
| torch.set_rng_state(rng_state) | |||
| for k in range(i): | |||
| _ = next(it) | |||
| (input, target) = next(it) | |||
| pred = flatten_predictions(pred) | |||
| # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||
| # seed = torch.rand(1).item() | |||
| # rng_state = torch.get_rng_state() | |||
| # torch.manual_seed(seed) | |||
| #it = iter(batch) | |||
| #torch.set_rng_state(rng_state) | |||
| #for k in range(i): | |||
| #_ = next(it) | |||
| #(input, target) = next(it) | |||
| (input, target) = gather_batch_indices(pred, indices) | |||
| loss = self.loss(input, target) | |||
| loss.backward() | |||
| self.opt.step() | |||
| @@ -1,4 +1,8 @@ | |||
| from icosagon.batch import PredictionsBatch | |||
| from icosagon.batch import PredictionsBatch, \ | |||
| FlatPredictions, \ | |||
| flatten_predictions, \ | |||
| BatchIndices, \ | |||
| gather_batch_indices | |||
| from icosagon.declayer import Predictions, \ | |||
| RelationPredictions, \ | |||
| RelationFamilyPredictions | |||
| @@ -6,6 +10,113 @@ from icosagon.trainprep import prepare_training, \ | |||
| TrainValTest | |||
| from icosagon.data import Data | |||
| import torch | |||
| import pytest | |||
| def test_flat_predictions_01(): | |||
| pred = FlatPredictions(torch.tensor([0, 1, 0, 1]), | |||
| torch.tensor([1, 0, 1, 0]), 'train') | |||
| assert torch.all(pred.predictions == torch.tensor([0, 1, 0, 1])) | |||
| assert torch.all(pred.truth == torch.tensor([1, 0, 1, 0])) | |||
| assert pred.part_type == 'train' | |||
| def test_flatten_predictions_01(): | |||
| rel_pred = RelationPredictions( | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||
| ) | |||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||
| pred = Predictions([ fam_pred ]) | |||
| pred_flat = flatten_predictions(pred, part_type='train') | |||
| assert torch.all(pred_flat.predictions == \ | |||
| torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1], dtype=torch.float32)) | |||
| assert torch.all(pred_flat.truth == \ | |||
| torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=torch.float32)) | |||
| assert pred_flat.part_type == 'train' | |||
| def test_flatten_predictions_02(): | |||
| rel_pred = RelationPredictions( | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||
| ) | |||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||
| pred = Predictions([ fam_pred ]) | |||
| pred_flat = flatten_predictions(pred, part_type='val') | |||
| assert len(pred_flat.predictions) == 0 | |||
| assert len(pred_flat.truth) == 0 | |||
| assert pred_flat.part_type == 'val' | |||
| def test_flatten_predictions_03(): | |||
| rel_pred = RelationPredictions( | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||
| ) | |||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||
| pred = Predictions([ fam_pred ]) | |||
| pred_flat = flatten_predictions(pred, part_type='test') | |||
| assert len(pred_flat.predictions) == 0 | |||
| assert len(pred_flat.truth) == 0 | |||
| assert pred_flat.part_type == 'test' | |||
| def test_flatten_predictions_04(): | |||
| rel_pred = RelationPredictions( | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||
| ) | |||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||
| pred = Predictions([ fam_pred ]) | |||
| with pytest.raises(TypeError): | |||
| pred_flat = flatten_predictions(1, part_type='test') | |||
| with pytest.raises(ValueError): | |||
| pred_flat = flatten_predictions(pred, part_type='x') | |||
| def test_batch_indices_01(): | |||
| indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train') | |||
| assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4])) | |||
| assert indices.part_type == 'train' | |||
| def test_gather_batch_indices_01(): | |||
| rel_pred = RelationPredictions( | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||
| ) | |||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||
| pred = Predictions([ fam_pred ]) | |||
| pred_flat = flatten_predictions(pred, part_type='train') | |||
| indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train') | |||
| (input, target) = gather_batch_indices(pred_flat, indices) | |||
| assert torch.all(input == \ | |||
| torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32)) | |||
| assert torch.all(target == \ | |||
| torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32)) | |||
| def test_predictions_batch_01(): | |||
| @@ -38,10 +149,13 @@ def test_predictions_batch_01(): | |||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||
| pred = Predictions([ fam_pred ]) | |||
| batch = PredictionsBatch(pred, part_type='train', batch_size=1) | |||
| pred_flat = flatten_predictions(pred, part_type='train') | |||
| batch = PredictionsBatch(prep_d, part_type='train', batch_size=1) | |||
| count = 0 | |||
| lst = [] | |||
| for (input, target) in batch: | |||
| for indices in batch: | |||
| (input, target) = gather_batch_indices(pred_flat, indices) | |||
| assert len(input) == 1 | |||
| assert len(target) == 1 | |||
| lst.append((input[0], target[0])) | |||