@@ -1,13 +1,72 @@ | |||||
from icosagon.declayer import Predictions | |||||
from .declayer import Predictions | |||||
import torch | import torch | ||||
from dataclasses import dataclass | |||||
from .trainprep import PreparedData | |||||
from typing import Tuple | |||||
@dataclass
class FlatPredictions(object):
    """Predictions of a single data partition flattened into 1-D tensors.

    Instances are built by flatten_predictions() and consumed by
    gather_batch_indices().
    """
    # Concatenated prediction values.
    predictions: torch.Tensor
    # Target label for each entry of `predictions` (same length).
    truth: torch.Tensor
    # Name of the partition the tensors were taken from
    # ('train', 'val' or 'test' when built by flatten_predictions()).
    part_type: str
def flatten_predictions(pred: Predictions,
        part_type: str = 'train') -> FlatPredictions:
    """Concatenate all edge predictions of one partition into flat tensors.

    For every relation type of every relation family in `pred`, the
    `part_type` partition of edges_pos, edges_neg, edges_back_pos and
    edges_back_neg is appended (in that order) to the output. Positive
    edge kinds get target 1, negative edge kinds get target 0.

    Args:
        pred: Predictions object to flatten.
        part_type: Partition to extract: 'train', 'val' or 'test'.

    Returns:
        FlatPredictions holding the concatenated predictions, the
        matching targets and `part_type`.

    Raises:
        TypeError: If `pred` is not a Predictions instance.
        ValueError: If `part_type` is not 'train', 'val' or 'test'.
    """
    if not isinstance(pred, Predictions):
        raise TypeError('pred must be an instance of Predictions')

    if part_type not in ['train', 'val', 'test']:
        raise ValueError('part_type must be set to train, val or test')

    # (attribute name on the relation predictions, target label)
    edge_types = [('edges_pos', 1), ('edges_neg', 0),
        ('edges_back_pos', 1), ('edges_back_neg', 0)]

    inputs = []
    targets = []
    for fam in pred.relation_families:
        for rel in fam.relation_types:
            for (et, tgt) in edge_types:
                edge_pred = getattr(getattr(rel, et), part_type)
                inputs.append(edge_pred)
                # Same shape/dtype/device as the predictions, filled
                # with the label of this edge kind.
                targets.append(torch.full_like(edge_pred, tgt))

    if not inputs:
        # torch.cat() rejects an empty list; with no relations at all
        # the flattened result is simply empty.
        return FlatPredictions(torch.zeros(0), torch.zeros(0), part_type)

    return FlatPredictions(torch.cat(inputs), torch.cat(targets), part_type)
@dataclass
class BatchIndices(object):
    """Indices selecting one batch out of flattened predictions.

    Yielded by PredictionsBatch and consumed by gather_batch_indices().
    """
    # Positions into the flattened prediction / truth tensors.
    indices: torch.Tensor
    # Name of the partition the indices refer to; gather_batch_indices()
    # requires it to equal the part_type of the FlatPredictions.
    part_type: str
def gather_batch_indices(pred: FlatPredictions,
        indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select one batch of predictions and targets by index.

    Args:
        pred: Flattened predictions to index into.
        indices: Batch indices; must refer to the same partition as `pred`.

    Returns:
        Tuple (predictions, truth) holding the entries of
        `pred.predictions` and `pred.truth` at `indices.indices`.

    Raises:
        TypeError: If `pred` is not a FlatPredictions or `indices` is
            not a BatchIndices.
        ValueError: If `pred` and `indices` have different part_type.
    """
    if not isinstance(pred, FlatPredictions):
        raise TypeError('pred must be an instance of FlatPredictions')

    if not isinstance(indices, BatchIndices):
        raise TypeError('indices must be an instance of BatchIndices')

    # Indices computed for one partition are meaningless for another.
    if pred.part_type != indices.part_type:
        raise ValueError('part_type must be the same in pred and indices')

    return (pred.predictions[indices.indices],
        pred.truth[indices.indices])
class PredictionsBatch(object): | class PredictionsBatch(object): | ||||
def __init__(self, pred: Predictions, part_type: str = 'train', | |||||
batch_size: int = 100, shuffle: bool = False) -> None: | |||||
def __init__(self, prep_d: PreparedData, part_type: str = 'train', | |||||
batch_size: int = 100, shuffle: bool = False, | |||||
generator: torch.Generator = None) -> None: | |||||
if not isinstance(pred, Predictions): | |||||
raise TypeError('pred must be an instance of Predictions') | |||||
if not isinstance(prep_d, PreparedData): | |||||
raise TypeError('prep_d must be an instance of PreparedData') | |||||
if part_type not in ['train', 'val', 'test']: | if part_type not in ['train', 'val', 'test']: | ||||
raise ValueError('part_type must be set to train, val or test') | raise ValueError('part_type must be set to train, val or test') | ||||
@@ -16,32 +75,28 @@ class PredictionsBatch(object): | |||||
shuffle = bool(shuffle) | shuffle = bool(shuffle) | ||||
self.predictions = pred | |||||
if generator is not None and not isinstance(generator, torch.Generator): | |||||
raise TypeError('generator must be an instance of torch.Generator') | |||||
self.prep_d = prep_d | |||||
self.part_type = part_type | self.part_type = part_type | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.generator = generator or torch.default_generator | |||||
def __iter__(self): | |||||
edge_types = [('edges_pos', 1), ('edges_neg', 0), | |||||
('edges_back_pos', 1), ('edges_back_neg', 0)] | |||||
input = [] | |||||
target = [] | |||||
for fam in self.predictions.relation_families: | |||||
count = 0 | |||||
for fam in prep_d.relation_families: | |||||
for rel in fam.relation_types: | for rel in fam.relation_types: | ||||
for (et, tgt) in edge_types: | |||||
edge_pred = getattr(getattr(rel, et), self.part_type) | |||||
input.append(edge_pred) | |||||
target.append(torch.ones_like(edge_pred) * tgt) | |||||
input = torch.cat(input) | |||||
target = torch.cat(target) | |||||
for et in ['edges_pos', 'edges_neg', | |||||
'edges_back_pos', 'edges_back_neg']: | |||||
count += len(getattr(getattr(rel, et), part_type)) | |||||
self.total_edge_count = count | |||||
def __iter__(self):
    """Yield BatchIndices covering all edges of the partition.

    Indices run over range(self.total_edge_count). When self.shuffle
    is set they are permuted first, drawing randomness from
    self.generator so that a seeded generator gives a reproducible
    batch order. Every batch holds self.batch_size indices except
    possibly the last one, which may be shorter.
    """
    values = torch.arange(self.total_edge_count)
    if self.shuffle:
        # Use the generator supplied to the constructor; without it the
        # permutation would ignore the caller's seed.
        perm = torch.randperm(len(values), generator=self.generator)
        values = values[perm]

    for i in range(0, len(values), self.batch_size):
        yield BatchIndices(values[i:i+self.batch_size], self.part_type)
@@ -1,6 +1,8 @@ | |||||
from .model import Model | from .model import Model | ||||
import torch | import torch | ||||
from .batch import PredictionsBatch | |||||
from .batch import PredictionsBatch, \ | |||||
flatten_predictions, \ | |||||
gather_batch_indices | |||||
from typing import Callable | from typing import Callable | ||||
from types import FunctionType | from types import FunctionType | ||||
@@ -9,7 +11,7 @@ class TrainLoop(object): | |||||
def __init__(self, model: Model, lr: float = 0.001, | def __init__(self, model: Model, lr: float = 0.001, | ||||
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | ||||
torch.nn.functional.binary_cross_entropy_with_logits, | torch.nn.functional.binary_cross_entropy_with_logits, | ||||
batch_size: int = 100) -> None: | |||||
batch_size: int = 100, generator: torch.Generator = None) -> None: | |||||
if not isinstance(model, Model): | if not isinstance(model, Model): | ||||
raise TypeError('model must be an instance of Model') | raise TypeError('model must be an instance of Model') | ||||
@@ -21,10 +23,14 @@ class TrainLoop(object): | |||||
batch_size = int(batch_size) | batch_size = int(batch_size) | ||||
if generator is not None and not isinstance(generator, torch.Generator): | |||||
raise TypeError('generator must be an instance of torch.Generator') | |||||
self.model = model | self.model = model | ||||
self.lr = lr | self.lr = lr | ||||
self.loss = loss | self.loss = loss | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.generator = generator or torch.default_generator | |||||
self.opt = None | self.opt = None | ||||
@@ -35,22 +41,25 @@ class TrainLoop(object): | |||||
self.opt = opt | self.opt = opt | ||||
def run_epoch(self): | def run_epoch(self): | ||||
pred = self.model(None) | |||||
batch = PredictionsBatch(pred, batch_size=self.batch_size) | |||||
n = len(list(iter(batch))) | |||||
batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size, | |||||
generator=self.generator) | |||||
# pred = self.model(None) | |||||
# n = len(list(iter(batch))) | |||||
loss_sum = 0 | loss_sum = 0 | ||||
for i in range(n): | |||||
for indices in batch: | |||||
self.opt.zero_grad() | self.opt.zero_grad() | ||||
pred = self.model(None) | pred = self.model(None) | ||||
batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||||
seed = torch.rand(1).item() | |||||
rng_state = torch.get_rng_state() | |||||
torch.manual_seed(seed) | |||||
it = iter(batch) | |||||
torch.set_rng_state(rng_state) | |||||
for k in range(i): | |||||
_ = next(it) | |||||
(input, target) = next(it) | |||||
pred = flatten_predictions(pred) | |||||
# batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||||
# seed = torch.rand(1).item() | |||||
# rng_state = torch.get_rng_state() | |||||
# torch.manual_seed(seed) | |||||
#it = iter(batch) | |||||
#torch.set_rng_state(rng_state) | |||||
#for k in range(i): | |||||
#_ = next(it) | |||||
#(input, target) = next(it) | |||||
(input, target) = gather_batch_indices(pred, indices) | |||||
loss = self.loss(input, target) | loss = self.loss(input, target) | ||||
loss.backward() | loss.backward() | ||||
self.opt.step() | self.opt.step() | ||||
@@ -1,4 +1,8 @@ | |||||
from icosagon.batch import PredictionsBatch | |||||
from icosagon.batch import PredictionsBatch, \ | |||||
FlatPredictions, \ | |||||
flatten_predictions, \ | |||||
BatchIndices, \ | |||||
gather_batch_indices | |||||
from icosagon.declayer import Predictions, \ | from icosagon.declayer import Predictions, \ | ||||
RelationPredictions, \ | RelationPredictions, \ | ||||
RelationFamilyPredictions | RelationFamilyPredictions | ||||
@@ -6,6 +10,113 @@ from icosagon.trainprep import prepare_training, \ | |||||
TrainValTest | TrainValTest | ||||
from icosagon.data import Data | from icosagon.data import Data | ||||
import torch | import torch | ||||
import pytest | |||||
def test_flat_predictions_01():
    """FlatPredictions stores its three fields unchanged."""
    pred = FlatPredictions(torch.tensor([0, 1, 0, 1]),
        torch.tensor([1, 0, 1, 0]), 'train')

    assert torch.all(pred.predictions == torch.tensor([0, 1, 0, 1]))
    assert torch.all(pred.truth == torch.tensor([1, 0, 1, 0]))
    assert pred.part_type == 'train'
def test_flatten_predictions_01():
    """Flattening 'train' concatenates predictions and labels them 1 / 0."""
    # Only the first two edge kinds carry (train) predictions; every
    # other partition is empty.
    train_values = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
    rel_pred = RelationPredictions(
        TrainValTest(train_values, torch.zeros(0), torch.zeros(0)),
        TrainValTest(train_values.clone(), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
    )
    fam_pred = RelationFamilyPredictions([ rel_pred ])
    pred = Predictions([ fam_pred ])

    pred_flat = flatten_predictions(pred, part_type='train')

    assert torch.all(pred_flat.predictions == \
        torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1], dtype=torch.float32))
    assert torch.all(pred_flat.truth == \
        torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=torch.float32))
    assert pred_flat.part_type == 'train'
def test_flatten_predictions_02():
    """Flattening 'val' yields empty tensors when no val data exists."""
    train_values = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
    rel_pred = RelationPredictions(
        TrainValTest(train_values, torch.zeros(0), torch.zeros(0)),
        TrainValTest(train_values.clone(), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
    )
    fam_pred = RelationFamilyPredictions([ rel_pred ])
    pred = Predictions([ fam_pred ])

    pred_flat = flatten_predictions(pred, part_type='val')

    assert len(pred_flat.predictions) == 0
    assert len(pred_flat.truth) == 0
    assert pred_flat.part_type == 'val'
def test_flatten_predictions_03():
    """Flattening 'test' yields empty tensors when no test data exists."""
    train_values = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
    rel_pred = RelationPredictions(
        TrainValTest(train_values, torch.zeros(0), torch.zeros(0)),
        TrainValTest(train_values.clone(), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
    )
    fam_pred = RelationFamilyPredictions([ rel_pred ])
    pred = Predictions([ fam_pred ])

    pred_flat = flatten_predictions(pred, part_type='test')

    assert len(pred_flat.predictions) == 0
    assert len(pred_flat.truth) == 0
    assert pred_flat.part_type == 'test'
def test_flatten_predictions_04():
    """flatten_predictions() rejects a wrong pred type and a bad part_type."""
    train_values = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
    rel_pred = RelationPredictions(
        TrainValTest(train_values, torch.zeros(0), torch.zeros(0)),
        TrainValTest(train_values.clone(), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
    )
    fam_pred = RelationFamilyPredictions([ rel_pred ])
    pred = Predictions([ fam_pred ])

    # The return value is irrelevant here: both calls must raise.
    with pytest.raises(TypeError):
        flatten_predictions(1, part_type='test')

    with pytest.raises(ValueError):
        flatten_predictions(pred, part_type='x')
def test_batch_indices_01():
    """BatchIndices stores its two fields unchanged."""
    indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train')

    assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4]))
    assert indices.part_type == 'train'
def test_gather_batch_indices_01():
    """gather_batch_indices() picks predictions and targets by position."""
    train_values = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
    rel_pred = RelationPredictions(
        TrainValTest(train_values, torch.zeros(0), torch.zeros(0)),
        TrainValTest(train_values.clone(), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
    )
    fam_pred = RelationFamilyPredictions([ rel_pred ])
    pred = Predictions([ fam_pred ])
    pred_flat = flatten_predictions(pred, part_type='train')

    # Positions 0, 2, 4 fall into the first five flattened entries and
    # 5, 7, 9 into the last five; all of them hold the value 1.
    indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train')
    (input, target) = gather_batch_indices(pred_flat, indices)

    assert torch.all(input == \
        torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32))
    assert torch.all(target == \
        torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32))
def test_predictions_batch_01(): | def test_predictions_batch_01(): | ||||
@@ -38,10 +149,13 @@ def test_predictions_batch_01(): | |||||
fam_pred = RelationFamilyPredictions([ rel_pred ]) | fam_pred = RelationFamilyPredictions([ rel_pred ]) | ||||
pred = Predictions([ fam_pred ]) | pred = Predictions([ fam_pred ]) | ||||
batch = PredictionsBatch(pred, part_type='train', batch_size=1) | |||||
pred_flat = flatten_predictions(pred, part_type='train') | |||||
batch = PredictionsBatch(prep_d, part_type='train', batch_size=1) | |||||
count = 0 | count = 0 | ||||
lst = [] | lst = [] | ||||
for (input, target) in batch: | |||||
for indices in batch: | |||||
(input, target) = gather_batch_indices(pred_flat, indices) | |||||
assert len(input) == 1 | assert len(input) == 1 | ||||
assert len(target) == 1 | assert len(target) == 1 | ||||
lst.append((input[0], target[0])) | lst.append((input[0], target[0])) | ||||