| @@ -1,13 +1,72 @@ | |||||
| from icosagon.declayer import Predictions | |||||
| from .declayer import Predictions | |||||
| import torch | import torch | ||||
| from dataclasses import dataclass | |||||
| from .trainprep import PreparedData | |||||
| from typing import Tuple | |||||
| @dataclass | |||||
| class FlatPredictions(object): | |||||
| predictions: torch.Tensor | |||||
| truth: torch.Tensor | |||||
| part_type: str | |||||
| def flatten_predictions(pred: Predictions, part_type: str = 'train'): | |||||
| if not isinstance(pred, Predictions): | |||||
| raise TypeError('pred must be an instance of Predictions') | |||||
| if part_type not in ['train', 'val', 'test']: | |||||
| raise ValueError('part_type must be set to train, val or test') | |||||
| edge_types = [('edges_pos', 1), ('edges_neg', 0), | |||||
| ('edges_back_pos', 1), ('edges_back_neg', 0)] | |||||
| input = [] | |||||
| target = [] | |||||
| for fam in pred.relation_families: | |||||
| for rel in fam.relation_types: | |||||
| for (et, tgt) in edge_types: | |||||
| edge_pred = getattr(getattr(rel, et), part_type) | |||||
| input.append(edge_pred) | |||||
| target.append(torch.ones_like(edge_pred) * tgt) | |||||
| input = torch.cat(input) | |||||
| target = torch.cat(target) | |||||
| return FlatPredictions(input, target, part_type) | |||||
| @dataclass | |||||
| class BatchIndices(object): | |||||
| indices: torch.Tensor | |||||
| part_type: str | |||||
| def gather_batch_indices(pred: FlatPredictions, | |||||
| indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
| if not isinstance(pred, FlatPredictions): | |||||
| raise TypeError('pred must be an instance of FlatPredictions') | |||||
| if not isinstance(indices, BatchIndices): | |||||
| raise TypeError('indices must be an instance of BatchIndices') | |||||
| if pred.part_type != indices.part_type: | |||||
| raise ValueError('part_type must be the same in pred and indices') | |||||
| return (pred.predictions[indices.indices], | |||||
| pred.truth[indices.indices]) | |||||
| class PredictionsBatch(object): | class PredictionsBatch(object): | ||||
| def __init__(self, pred: Predictions, part_type: str = 'train', | |||||
| batch_size: int = 100, shuffle: bool = False) -> None: | |||||
| def __init__(self, prep_d: PreparedData, part_type: str = 'train', | |||||
| batch_size: int = 100, shuffle: bool = False, | |||||
| generator: torch.Generator = None) -> None: | |||||
| if not isinstance(pred, Predictions): | |||||
| raise TypeError('pred must be an instance of Predictions') | |||||
| if not isinstance(prep_d, PreparedData): | |||||
| raise TypeError('prep_d must be an instance of PreparedData') | |||||
| if part_type not in ['train', 'val', 'test']: | if part_type not in ['train', 'val', 'test']: | ||||
| raise ValueError('part_type must be set to train, val or test') | raise ValueError('part_type must be set to train, val or test') | ||||
| @@ -16,32 +75,28 @@ class PredictionsBatch(object): | |||||
| shuffle = bool(shuffle) | shuffle = bool(shuffle) | ||||
| self.predictions = pred | |||||
| if generator is not None and not isinstance(generator, torch.Generator): | |||||
| raise TypeError('generator must be an instance of torch.Generator') | |||||
| self.prep_d = prep_d | |||||
| self.part_type = part_type | self.part_type = part_type | ||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.shuffle = shuffle | self.shuffle = shuffle | ||||
| self.generator = generator or torch.default_generator | |||||
| def __iter__(self): | |||||
| edge_types = [('edges_pos', 1), ('edges_neg', 0), | |||||
| ('edges_back_pos', 1), ('edges_back_neg', 0)] | |||||
| input = [] | |||||
| target = [] | |||||
| for fam in self.predictions.relation_families: | |||||
| count = 0 | |||||
| for fam in prep_d.relation_families: | |||||
| for rel in fam.relation_types: | for rel in fam.relation_types: | ||||
| for (et, tgt) in edge_types: | |||||
| edge_pred = getattr(getattr(rel, et), self.part_type) | |||||
| input.append(edge_pred) | |||||
| target.append(torch.ones_like(edge_pred) * tgt) | |||||
| input = torch.cat(input) | |||||
| target = torch.cat(target) | |||||
| for et in ['edges_pos', 'edges_neg', | |||||
| 'edges_back_pos', 'edges_back_neg']: | |||||
| count += len(getattr(getattr(rel, et), part_type)) | |||||
| self.total_edge_count = count | |||||
| def __iter__(self): | |||||
| values = torch.arange(self.total_edge_count) | |||||
| if self.shuffle: | if self.shuffle: | ||||
| perm = torch.randperm(len(input)) | |||||
| input = input[perm] | |||||
| target = target[perm] | |||||
| perm = torch.randperm(len(values)) | |||||
| values = values[perm] | |||||
| for i in range(0, len(input), self.batch_size): | |||||
| yield (input[i:i+self.batch_size], target[i:i+self.batch_size]) | |||||
| for i in range(0, len(values), self.batch_size): | |||||
| yield BatchIndices(values[i:i+self.batch_size], self.part_type) | |||||
| @@ -1,6 +1,8 @@ | |||||
| from .model import Model | from .model import Model | ||||
| import torch | import torch | ||||
| from .batch import PredictionsBatch | |||||
| from .batch import PredictionsBatch, \ | |||||
| flatten_predictions, \ | |||||
| gather_batch_indices | |||||
| from typing import Callable | from typing import Callable | ||||
| from types import FunctionType | from types import FunctionType | ||||
| @@ -9,7 +11,7 @@ class TrainLoop(object): | |||||
| def __init__(self, model: Model, lr: float = 0.001, | def __init__(self, model: Model, lr: float = 0.001, | ||||
| loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | ||||
| torch.nn.functional.binary_cross_entropy_with_logits, | torch.nn.functional.binary_cross_entropy_with_logits, | ||||
| batch_size: int = 100) -> None: | |||||
| batch_size: int = 100, generator: torch.Generator = None) -> None: | |||||
| if not isinstance(model, Model): | if not isinstance(model, Model): | ||||
| raise TypeError('model must be an instance of Model') | raise TypeError('model must be an instance of Model') | ||||
| @@ -21,10 +23,14 @@ class TrainLoop(object): | |||||
| batch_size = int(batch_size) | batch_size = int(batch_size) | ||||
| if generator is not None and not isinstance(generator, torch.Generator): | |||||
| raise TypeError('generator must be an instance of torch.Generator') | |||||
| self.model = model | self.model = model | ||||
| self.lr = lr | self.lr = lr | ||||
| self.loss = loss | self.loss = loss | ||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.generator = generator or torch.default_generator | |||||
| self.opt = None | self.opt = None | ||||
| @@ -35,22 +41,25 @@ class TrainLoop(object): | |||||
| self.opt = opt | self.opt = opt | ||||
| def run_epoch(self): | def run_epoch(self): | ||||
| pred = self.model(None) | |||||
| batch = PredictionsBatch(pred, batch_size=self.batch_size) | |||||
| n = len(list(iter(batch))) | |||||
| batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size, | |||||
| generator=self.generator) | |||||
| # pred = self.model(None) | |||||
| # n = len(list(iter(batch))) | |||||
| loss_sum = 0 | loss_sum = 0 | ||||
| for i in range(n): | |||||
| for indices in batch: | |||||
| self.opt.zero_grad() | self.opt.zero_grad() | ||||
| pred = self.model(None) | pred = self.model(None) | ||||
| batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||||
| seed = torch.rand(1).item() | |||||
| rng_state = torch.get_rng_state() | |||||
| torch.manual_seed(seed) | |||||
| it = iter(batch) | |||||
| torch.set_rng_state(rng_state) | |||||
| for k in range(i): | |||||
| _ = next(it) | |||||
| (input, target) = next(it) | |||||
| pred = flatten_predictions(pred) | |||||
| # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||||
| # seed = torch.rand(1).item() | |||||
| # rng_state = torch.get_rng_state() | |||||
| # torch.manual_seed(seed) | |||||
| #it = iter(batch) | |||||
| #torch.set_rng_state(rng_state) | |||||
| #for k in range(i): | |||||
| #_ = next(it) | |||||
| #(input, target) = next(it) | |||||
| (input, target) = gather_batch_indices(pred, indices) | |||||
| loss = self.loss(input, target) | loss = self.loss(input, target) | ||||
| loss.backward() | loss.backward() | ||||
| self.opt.step() | self.opt.step() | ||||
| @@ -1,4 +1,8 @@ | |||||
| from icosagon.batch import PredictionsBatch | |||||
| from icosagon.batch import PredictionsBatch, \ | |||||
| FlatPredictions, \ | |||||
| flatten_predictions, \ | |||||
| BatchIndices, \ | |||||
| gather_batch_indices | |||||
| from icosagon.declayer import Predictions, \ | from icosagon.declayer import Predictions, \ | ||||
| RelationPredictions, \ | RelationPredictions, \ | ||||
| RelationFamilyPredictions | RelationFamilyPredictions | ||||
| @@ -6,6 +10,113 @@ from icosagon.trainprep import prepare_training, \ | |||||
| TrainValTest | TrainValTest | ||||
| from icosagon.data import Data | from icosagon.data import Data | ||||
| import torch | import torch | ||||
| import pytest | |||||
| def test_flat_predictions_01(): | |||||
| pred = FlatPredictions(torch.tensor([0, 1, 0, 1]), | |||||
| torch.tensor([1, 0, 1, 0]), 'train') | |||||
| assert torch.all(pred.predictions == torch.tensor([0, 1, 0, 1])) | |||||
| assert torch.all(pred.truth == torch.tensor([1, 0, 1, 0])) | |||||
| assert pred.part_type == 'train' | |||||
| def test_flatten_predictions_01(): | |||||
| rel_pred = RelationPredictions( | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||||
| ) | |||||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||||
| pred = Predictions([ fam_pred ]) | |||||
| pred_flat = flatten_predictions(pred, part_type='train') | |||||
| assert torch.all(pred_flat.predictions == \ | |||||
| torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1], dtype=torch.float32)) | |||||
| assert torch.all(pred_flat.truth == \ | |||||
| torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=torch.float32)) | |||||
| assert pred_flat.part_type == 'train' | |||||
| def test_flatten_predictions_02(): | |||||
| rel_pred = RelationPredictions( | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||||
| ) | |||||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||||
| pred = Predictions([ fam_pred ]) | |||||
| pred_flat = flatten_predictions(pred, part_type='val') | |||||
| assert len(pred_flat.predictions) == 0 | |||||
| assert len(pred_flat.truth) == 0 | |||||
| assert pred_flat.part_type == 'val' | |||||
| def test_flatten_predictions_03(): | |||||
| rel_pred = RelationPredictions( | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||||
| ) | |||||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||||
| pred = Predictions([ fam_pred ]) | |||||
| pred_flat = flatten_predictions(pred, part_type='test') | |||||
| assert len(pred_flat.predictions) == 0 | |||||
| assert len(pred_flat.truth) == 0 | |||||
| assert pred_flat.part_type == 'test' | |||||
| def test_flatten_predictions_04(): | |||||
| rel_pred = RelationPredictions( | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||||
| ) | |||||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||||
| pred = Predictions([ fam_pred ]) | |||||
| with pytest.raises(TypeError): | |||||
| pred_flat = flatten_predictions(1, part_type='test') | |||||
| with pytest.raises(ValueError): | |||||
| pred_flat = flatten_predictions(pred, part_type='x') | |||||
| def test_batch_indices_01(): | |||||
| indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train') | |||||
| assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4])) | |||||
| assert indices.part_type == 'train' | |||||
| def test_gather_batch_indices_01(): | |||||
| rel_pred = RelationPredictions( | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), | |||||
| TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) | |||||
| ) | |||||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||||
| pred = Predictions([ fam_pred ]) | |||||
| pred_flat = flatten_predictions(pred, part_type='train') | |||||
| indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train') | |||||
| (input, target) = gather_batch_indices(pred_flat, indices) | |||||
| assert torch.all(input == \ | |||||
| torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32)) | |||||
| assert torch.all(target == \ | |||||
| torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32)) | |||||
| def test_predictions_batch_01(): | def test_predictions_batch_01(): | ||||
| @@ -38,10 +149,13 @@ def test_predictions_batch_01(): | |||||
| fam_pred = RelationFamilyPredictions([ rel_pred ]) | fam_pred = RelationFamilyPredictions([ rel_pred ]) | ||||
| pred = Predictions([ fam_pred ]) | pred = Predictions([ fam_pred ]) | ||||
| batch = PredictionsBatch(pred, part_type='train', batch_size=1) | |||||
| pred_flat = flatten_predictions(pred, part_type='train') | |||||
| batch = PredictionsBatch(prep_d, part_type='train', batch_size=1) | |||||
| count = 0 | count = 0 | ||||
| lst = [] | lst = [] | ||||
| for (input, target) in batch: | |||||
| for indices in batch: | |||||
| (input, target) = gather_batch_indices(pred_flat, indices) | |||||
| assert len(input) == 1 | assert len(input) == 1 | ||||
| assert len(target) == 1 | assert len(target) == 1 | ||||
| lst.append((input[0], target[0])) | lst.append((input[0], target[0])) | ||||