@@ -1,13 +1,72 @@ | |||
from icosagon.declayer import Predictions | |||
from .declayer import Predictions | |||
import torch | |||
from dataclasses import dataclass | |||
from .trainprep import PreparedData | |||
from typing import Tuple | |||
@dataclass
class FlatPredictions:
    """Predictions and their target labels stored as flat tensors.

    The record also carries the name of the data partition the values
    were taken from, so that it can be checked against batch indices.
    """

    # Concatenated prediction values.
    predictions: torch.Tensor
    # Target labels, aligned element-wise with ``predictions``.
    truth: torch.Tensor
    # Name of the partition the values belong to (e.g. 'train').
    part_type: str
def flatten_predictions(pred: Predictions, part_type: str = 'train'):
    """Concatenate all edge predictions of one partition into flat tensors.

    Walks every relation type of every relation family in ``pred`` and,
    for each of the four edge groups (positive, negative, backward
    positive, backward negative), collects the tensor stored under
    ``part_type``. Positive groups are labelled 1 and negative groups 0.

    Args:
        pred: The structured predictions to flatten.
        part_type: Which partition to read: 'train', 'val' or 'test'.

    Returns:
        A FlatPredictions holding the concatenated predictions, the
        matching target labels and ``part_type``.

    Raises:
        TypeError: If ``pred`` is not a Predictions instance.
        ValueError: If ``part_type`` is not 'train', 'val' or 'test'.
    """
    if not isinstance(pred, Predictions):
        raise TypeError('pred must be an instance of Predictions')

    if part_type not in ('train', 'val', 'test'):
        raise ValueError('part_type must be set to train, val or test')

    # (attribute name on the relation, target label for that edge group)
    labelled_edge_groups = (
        ('edges_pos', 1),
        ('edges_neg', 0),
        ('edges_back_pos', 1),
        ('edges_back_neg', 0),
    )

    scores = []
    labels = []
    for family in pred.relation_families:
        for relation in family.relation_types:
            for group_name, label in labelled_edge_groups:
                part_scores = getattr(getattr(relation, group_name), part_type)
                scores.append(part_scores)
                # Same shape and dtype as the scores, filled with the label.
                labels.append(torch.ones_like(part_scores) * label)

    flat_scores = torch.cat(scores)
    flat_labels = torch.cat(labels)
    return FlatPredictions(flat_scores, flat_labels, part_type)
@dataclass
class BatchIndices:
    """Indices selecting one batch, tagged with their partition name."""

    # Index tensor used to pick entries out of flattened predictions.
    indices: torch.Tensor
    # Name of the partition the indices refer to (e.g. 'train').
    part_type: str
def gather_batch_indices(pred: FlatPredictions,
                         indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the entries of a batch out of flattened predictions.

    Args:
        pred: Flattened predictions and their target labels.
        indices: Indices of the batch to extract.

    Returns:
        A ``(predictions, truth)`` tuple, each indexed by
        ``indices.indices``.

    Raises:
        TypeError: If ``pred`` is not a FlatPredictions or ``indices``
            is not a BatchIndices.
        ValueError: If the two arguments refer to different partitions.
    """
    if not isinstance(pred, FlatPredictions):
        raise TypeError('pred must be an instance of FlatPredictions')

    if not isinstance(indices, BatchIndices):
        raise TypeError('indices must be an instance of BatchIndices')

    # Indices are only meaningful for the partition they were drawn for.
    if pred.part_type != indices.part_type:
        raise ValueError('part_type must be the same in pred and indices')

    selector = indices.indices
    batch_predictions = pred.predictions[selector]
    batch_truth = pred.truth[selector]
    return (batch_predictions, batch_truth)
class PredictionsBatch(object): | |||
def __init__(self, pred: Predictions, part_type: str = 'train', | |||
batch_size: int = 100, shuffle: bool = False) -> None: | |||
def __init__(self, prep_d: PreparedData, part_type: str = 'train', | |||
batch_size: int = 100, shuffle: bool = False, | |||
generator: torch.Generator = None) -> None: | |||
if not isinstance(pred, Predictions): | |||
raise TypeError('pred must be an instance of Predictions') | |||
if not isinstance(prep_d, PreparedData): | |||
raise TypeError('prep_d must be an instance of PreparedData') | |||
if part_type not in ['train', 'val', 'test']: | |||
raise ValueError('part_type must be set to train, val or test') | |||
@@ -16,32 +75,28 @@ class PredictionsBatch(object): | |||
shuffle = bool(shuffle) | |||
self.predictions = pred | |||
if generator is not None and not isinstance(generator, torch.Generator): | |||
raise TypeError('generator must be an instance of torch.Generator') | |||
self.prep_d = prep_d | |||
self.part_type = part_type | |||
self.batch_size = batch_size | |||
self.shuffle = shuffle | |||
self.generator = generator or torch.default_generator | |||
def __iter__(self): | |||
edge_types = [('edges_pos', 1), ('edges_neg', 0), | |||
('edges_back_pos', 1), ('edges_back_neg', 0)] | |||
input = [] | |||
target = [] | |||
for fam in self.predictions.relation_families: | |||
count = 0 | |||
for fam in prep_d.relation_families: | |||
for rel in fam.relation_types: | |||
for (et, tgt) in edge_types: | |||
edge_pred = getattr(getattr(rel, et), self.part_type) | |||
input.append(edge_pred) | |||
target.append(torch.ones_like(edge_pred) * tgt) | |||
input = torch.cat(input) | |||
target = torch.cat(target) | |||
for et in ['edges_pos', 'edges_neg', | |||
'edges_back_pos', 'edges_back_neg']: | |||
count += len(getattr(getattr(rel, et), part_type)) | |||
self.total_edge_count = count | |||
def __iter__(self): | |||
values = torch.arange(self.total_edge_count) | |||
if self.shuffle: | |||
perm = torch.randperm(len(input)) | |||
input = input[perm] | |||
target = target[perm] | |||
perm = torch.randperm(len(values)) | |||
values = values[perm] | |||
for i in range(0, len(input), self.batch_size): | |||
yield (input[i:i+self.batch_size], target[i:i+self.batch_size]) | |||
for i in range(0, len(values), self.batch_size): | |||
yield BatchIndices(values[i:i+self.batch_size], self.part_type) |
@@ -1,6 +1,8 @@ | |||
from .model import Model | |||
import torch | |||
from .batch import PredictionsBatch | |||
from .batch import PredictionsBatch, \ | |||
flatten_predictions, \ | |||
gather_batch_indices | |||
from typing import Callable | |||
from types import FunctionType | |||
@@ -9,7 +11,7 @@ class TrainLoop(object): | |||
def __init__(self, model: Model, lr: float = 0.001, | |||
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | |||
torch.nn.functional.binary_cross_entropy_with_logits, | |||
batch_size: int = 100) -> None: | |||
batch_size: int = 100, generator: torch.Generator = None) -> None: | |||
if not isinstance(model, Model): | |||
raise TypeError('model must be an instance of Model') | |||
@@ -21,10 +23,14 @@ class TrainLoop(object): | |||
batch_size = int(batch_size) | |||
if generator is not None and not isinstance(generator, torch.Generator): | |||
raise TypeError('generator must be an instance of torch.Generator') | |||
self.model = model | |||
self.lr = lr | |||
self.loss = loss | |||
self.batch_size = batch_size | |||
self.generator = generator or torch.default_generator | |||
self.opt = None | |||
@@ -35,22 +41,25 @@ class TrainLoop(object): | |||
self.opt = opt | |||
def run_epoch(self): | |||
pred = self.model(None) | |||
batch = PredictionsBatch(pred, batch_size=self.batch_size) | |||
n = len(list(iter(batch))) | |||
batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size, | |||
generator=self.generator) | |||
# pred = self.model(None) | |||
# n = len(list(iter(batch))) | |||
loss_sum = 0 | |||
for i in range(n): | |||
for indices in batch: | |||
self.opt.zero_grad() | |||
pred = self.model(None) | |||
batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||
seed = torch.rand(1).item() | |||
rng_state = torch.get_rng_state() | |||
torch.manual_seed(seed) | |||
it = iter(batch) | |||
torch.set_rng_state(rng_state) | |||
for k in range(i): | |||
_ = next(it) | |||
(input, target) = next(it) | |||
pred = flatten_predictions(pred) | |||
# batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||
# seed = torch.rand(1).item() | |||
# rng_state = torch.get_rng_state() | |||
# torch.manual_seed(seed) | |||
#it = iter(batch) | |||
#torch.set_rng_state(rng_state) | |||
#for k in range(i): | |||
#_ = next(it) | |||
#(input, target) = next(it) | |||
(input, target) = gather_batch_indices(pred, indices) | |||
loss = self.loss(input, target) | |||
loss.backward() | |||
self.opt.step() | |||
@@ -1,4 +1,8 @@ | |||
from icosagon.batch import PredictionsBatch | |||
from icosagon.batch import PredictionsBatch, \ | |||
FlatPredictions, \ | |||
flatten_predictions, \ | |||
BatchIndices, \ | |||
gather_batch_indices | |||
from icosagon.declayer import Predictions, \ | |||
RelationPredictions, \ | |||
RelationFamilyPredictions | |||
@@ -6,6 +10,113 @@ from icosagon.trainprep import prepare_training, \ | |||
TrainValTest | |||
from icosagon.data import Data | |||
import torch | |||
import pytest | |||
def test_flat_predictions_01():
    """FlatPredictions keeps the values it was constructed with."""
    flat = FlatPredictions(torch.tensor([0, 1, 0, 1]),
                           torch.tensor([1, 0, 1, 0]),
                           'train')

    expected_predictions = torch.tensor([0, 1, 0, 1])
    expected_truth = torch.tensor([1, 0, 1, 0])

    assert torch.all(flat.predictions == expected_predictions)
    assert torch.all(flat.truth == expected_truth)
    assert flat.part_type == 'train'
def test_flatten_predictions_01():
    """Flattening the 'train' partition concatenates positive then negative edges."""

    def train_only(values):
        # Partition triple with data in 'train' and empty 'val'/'test'.
        return TrainValTest(torch.tensor(values, dtype=torch.float32),
                            torch.zeros(0), torch.zeros(0))

    def empty():
        # Partition triple with no data at all.
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(
        train_only([1, 0, 1, 0, 1]),
        train_only([1, 0, 1, 0, 1]),
        empty(),
        empty()
    )
    pred = Predictions([RelationFamilyPredictions([rel_pred])])

    pred_flat = flatten_predictions(pred, part_type='train')

    expected_predictions = torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1],
                                        dtype=torch.float32)
    expected_truth = torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                                  dtype=torch.float32)

    assert torch.all(pred_flat.predictions == expected_predictions)
    assert torch.all(pred_flat.truth == expected_truth)
    assert pred_flat.part_type == 'train'
def test_flatten_predictions_02():
    """Flattening the empty 'val' partition yields empty tensors."""

    def train_only(values):
        # Partition triple with data in 'train' and empty 'val'/'test'.
        return TrainValTest(torch.tensor(values, dtype=torch.float32),
                            torch.zeros(0), torch.zeros(0))

    def empty():
        # Partition triple with no data at all.
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(
        train_only([1, 0, 1, 0, 1]),
        train_only([1, 0, 1, 0, 1]),
        empty(),
        empty()
    )
    pred = Predictions([RelationFamilyPredictions([rel_pred])])

    pred_flat = flatten_predictions(pred, part_type='val')

    assert len(pred_flat.predictions) == 0
    assert len(pred_flat.truth) == 0
    assert pred_flat.part_type == 'val'
def test_flatten_predictions_03():
    """Flattening the empty 'test' partition yields empty tensors."""

    def train_only(values):
        # Partition triple with data in 'train' and empty 'val'/'test'.
        return TrainValTest(torch.tensor(values, dtype=torch.float32),
                            torch.zeros(0), torch.zeros(0))

    def empty():
        # Partition triple with no data at all.
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(
        train_only([1, 0, 1, 0, 1]),
        train_only([1, 0, 1, 0, 1]),
        empty(),
        empty()
    )
    pred = Predictions([RelationFamilyPredictions([rel_pred])])

    pred_flat = flatten_predictions(pred, part_type='test')

    assert len(pred_flat.predictions) == 0
    assert len(pred_flat.truth) == 0
    assert pred_flat.part_type == 'test'
def test_flatten_predictions_04():
    """flatten_predictions rejects a non-Predictions input and an unknown part_type."""

    def train_only(values):
        # Partition triple with data in 'train' and empty 'val'/'test'.
        return TrainValTest(torch.tensor(values, dtype=torch.float32),
                            torch.zeros(0), torch.zeros(0))

    def empty():
        # Partition triple with no data at all.
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(
        train_only([1, 0, 1, 0, 1]),
        train_only([1, 0, 1, 0, 1]),
        empty(),
        empty()
    )
    pred = Predictions([RelationFamilyPredictions([rel_pred])])

    # Wrong type for the predictions argument.
    with pytest.raises(TypeError):
        flatten_predictions(1, part_type='test')

    # Partition name outside train/val/test.
    with pytest.raises(ValueError):
        flatten_predictions(pred, part_type='x')
def test_batch_indices_01():
    """BatchIndices keeps the values it was constructed with."""
    batch_indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train')

    expected_indices = torch.tensor([0, 1, 2, 3, 4])

    assert torch.all(batch_indices.indices == expected_indices)
    assert batch_indices.part_type == 'train'
def test_gather_batch_indices_01():
    """gather_batch_indices selects matching predictions and labels by index."""

    def train_only(values):
        # Partition triple with data in 'train' and empty 'val'/'test'.
        return TrainValTest(torch.tensor(values, dtype=torch.float32),
                            torch.zeros(0), torch.zeros(0))

    def empty():
        # Partition triple with no data at all.
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(
        train_only([1, 0, 1, 0, 1]),
        train_only([1, 0, 1, 0, 1]),
        empty(),
        empty()
    )
    pred = Predictions([RelationFamilyPredictions([rel_pred])])
    pred_flat = flatten_predictions(pred, part_type='train')

    # Three entries from the positive half, three from the negative half.
    batch_indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train')
    (batch_input, batch_target) = gather_batch_indices(pred_flat, batch_indices)

    expected_input = torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32)
    expected_target = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32)

    assert torch.all(batch_input == expected_input)
    assert torch.all(batch_target == expected_target)
def test_predictions_batch_01(): | |||
@@ -38,10 +149,13 @@ def test_predictions_batch_01(): | |||
fam_pred = RelationFamilyPredictions([ rel_pred ]) | |||
pred = Predictions([ fam_pred ]) | |||
batch = PredictionsBatch(pred, part_type='train', batch_size=1) | |||
pred_flat = flatten_predictions(pred, part_type='train') | |||
batch = PredictionsBatch(prep_d, part_type='train', batch_size=1) | |||
count = 0 | |||
lst = [] | |||
for (input, target) in batch: | |||
for indices in batch: | |||
(input, target) = gather_batch_indices(pred_flat, indices) | |||
assert len(input) == 1 | |||
assert len(target) == 1 | |||
lst.append((input[0], target[0])) | |||