IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Browse Source

Introduce FlatPredictions, flatten_predictions(), BatchIndices, gather_batch_indices().

master
Stanislaw Adaszewski 3 years ago
parent
commit
c023d35c46
3 changed files with 222 additions and 44 deletions
  1. +81
    -26
      src/icosagon/batch.py
  2. +24
    -15
      src/icosagon/trainloop.py
  3. +117
    -3
      tests/icosagon/test_batch.py

+ 81
- 26
src/icosagon/batch.py View File

@@ -1,13 +1,72 @@
from icosagon.declayer import Predictions
from .declayer import Predictions
import torch
from dataclasses import dataclass
from .trainprep import PreparedData
from typing import Tuple
@dataclass
class FlatPredictions:
    """Predictions and their target labels stored as flat tensors.

    Holds one tensor of prediction values, one tensor of target values and
    the name of the data partition they were taken from.
    """

    # Prediction values, one entry per edge.
    predictions: torch.Tensor
    # Target values; expected to line up index-for-index with `predictions`.
    truth: torch.Tensor
    # Name of the data partition the tensors belong to (e.g. 'train').
    part_type: str
def flatten_predictions(pred: Predictions, part_type: str = 'train'):
    """Concatenate all per-relation predictions of one partition.

    Walks every relation type of every relation family in `pred` and, for
    each of the four edge kinds (positive, negative, back-positive,
    back-negative), takes the tensor stored under `part_type`. Positive
    kinds get a target of 1, negative kinds a target of 0.

    Args:
        pred: Structured predictions to flatten.
        part_type: Partition to read; one of 'train', 'val' or 'test'.

    Returns:
        A FlatPredictions with the concatenated predictions, the matching
        targets and `part_type`.

    Raises:
        TypeError: If `pred` is not a Predictions instance.
        ValueError: If `part_type` is not 'train', 'val' or 'test'.
    """
    if not isinstance(pred, Predictions):
        raise TypeError('pred must be an instance of Predictions')

    if part_type not in ['train', 'val', 'test']:
        raise ValueError('part_type must be set to train, val or test')

    # (attribute name on a relation, target value assigned to its edges)
    labelled_attrs = (('edges_pos', 1), ('edges_neg', 0),
                      ('edges_back_pos', 1), ('edges_back_neg', 0))

    score_chunks = []
    label_chunks = []
    for family in pred.relation_families:
        for relation in family.relation_types:
            for attr_name, label in labelled_attrs:
                scores = getattr(getattr(relation, attr_name), part_type)
                score_chunks.append(scores)
                # Same shape and dtype handling as the scores, filled with the label.
                label_chunks.append(torch.ones_like(scores) * label)

    flat_scores = torch.cat(score_chunks)
    flat_labels = torch.cat(label_chunks)

    return FlatPredictions(flat_scores, flat_labels, part_type)
@dataclass
class BatchIndices:
    """Indices selecting one batch, tagged with the partition they refer to."""

    # Tensor of positions used to index flattened predictions.
    indices: torch.Tensor
    # Name of the data partition the indices belong to (e.g. 'train').
    part_type: str
def gather_batch_indices(pred: FlatPredictions,
        indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the batch entries out of flattened predictions.

    Args:
        pred: Flattened predictions and targets.
        indices: Batch indices into the flattened tensors.

    Returns:
        A pair (predictions, truth), each indexed by `indices.indices`.

    Raises:
        TypeError: If `pred` is not a FlatPredictions or `indices` is not
            a BatchIndices.
        ValueError: If the two arguments carry different part_type values.
    """
    type_checks = (
        (pred, FlatPredictions,
            'pred must be an instance of FlatPredictions'),
        (indices, BatchIndices,
            'indices must be an instance of BatchIndices'),
    )
    for value, expected_type, message in type_checks:
        if not isinstance(value, expected_type):
            raise TypeError(message)

    # Indices computed for one partition are meaningless for another.
    if pred.part_type != indices.part_type:
        raise ValueError('part_type must be the same in pred and indices')

    selection = indices.indices
    batch_predictions = pred.predictions[selection]
    batch_truth = pred.truth[selection]
    return (batch_predictions, batch_truth)
class PredictionsBatch(object):
def __init__(self, pred: Predictions, part_type: str = 'train',
batch_size: int = 100, shuffle: bool = False) -> None:
def __init__(self, prep_d: PreparedData, part_type: str = 'train',
batch_size: int = 100, shuffle: bool = False,
generator: torch.Generator = None) -> None:
if not isinstance(pred, Predictions):
raise TypeError('pred must be an instance of Predictions')
if not isinstance(prep_d, PreparedData):
raise TypeError('prep_d must be an instance of PreparedData')
if part_type not in ['train', 'val', 'test']:
raise ValueError('part_type must be set to train, val or test')
@@ -16,32 +75,28 @@ class PredictionsBatch(object):
shuffle = bool(shuffle)
self.predictions = pred
if generator is not None and not isinstance(generator, torch.Generator):
raise TypeError('generator must be an instance of torch.Generator')
self.prep_d = prep_d
self.part_type = part_type
self.batch_size = batch_size
self.shuffle = shuffle
self.generator = generator or torch.default_generator
def __iter__(self):
edge_types = [('edges_pos', 1), ('edges_neg', 0),
('edges_back_pos', 1), ('edges_back_neg', 0)]
input = []
target = []
for fam in self.predictions.relation_families:
count = 0
for fam in prep_d.relation_families:
for rel in fam.relation_types:
for (et, tgt) in edge_types:
edge_pred = getattr(getattr(rel, et), self.part_type)
input.append(edge_pred)
target.append(torch.ones_like(edge_pred) * tgt)
input = torch.cat(input)
target = torch.cat(target)
for et in ['edges_pos', 'edges_neg',
'edges_back_pos', 'edges_back_neg']:
count += len(getattr(getattr(rel, et), part_type))
self.total_edge_count = count
def __iter__(self):
values = torch.arange(self.total_edge_count)
if self.shuffle:
perm = torch.randperm(len(input))
input = input[perm]
target = target[perm]
perm = torch.randperm(len(values))
values = values[perm]
for i in range(0, len(input), self.batch_size):
yield (input[i:i+self.batch_size], target[i:i+self.batch_size])
for i in range(0, len(values), self.batch_size):
yield BatchIndices(values[i:i+self.batch_size], self.part_type)

+ 24
- 15
src/icosagon/trainloop.py View File

@@ -1,6 +1,8 @@
from .model import Model
import torch
from .batch import PredictionsBatch
from .batch import PredictionsBatch, \
flatten_predictions, \
gather_batch_indices
from typing import Callable
from types import FunctionType
@@ -9,7 +11,7 @@ class TrainLoop(object):
def __init__(self, model: Model, lr: float = 0.001,
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
torch.nn.functional.binary_cross_entropy_with_logits,
batch_size: int = 100) -> None:
batch_size: int = 100, generator: torch.Generator = None) -> None:
if not isinstance(model, Model):
raise TypeError('model must be an instance of Model')
@@ -21,10 +23,14 @@ class TrainLoop(object):
batch_size = int(batch_size)
if generator is not None and not isinstance(generator, torch.Generator):
raise TypeError('generator must be an instance of torch.Generator')
self.model = model
self.lr = lr
self.loss = loss
self.batch_size = batch_size
self.generator = generator or torch.default_generator
self.opt = None
@@ -35,22 +41,25 @@ class TrainLoop(object):
self.opt = opt
def run_epoch(self):
pred = self.model(None)
batch = PredictionsBatch(pred, batch_size=self.batch_size)
n = len(list(iter(batch)))
batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
generator=self.generator)
# pred = self.model(None)
# n = len(list(iter(batch)))
loss_sum = 0
for i in range(n):
for indices in batch:
self.opt.zero_grad()
pred = self.model(None)
batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
seed = torch.rand(1).item()
rng_state = torch.get_rng_state()
torch.manual_seed(seed)
it = iter(batch)
torch.set_rng_state(rng_state)
for k in range(i):
_ = next(it)
(input, target) = next(it)
pred = flatten_predictions(pred)
# batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
# seed = torch.rand(1).item()
# rng_state = torch.get_rng_state()
# torch.manual_seed(seed)
#it = iter(batch)
#torch.set_rng_state(rng_state)
#for k in range(i):
#_ = next(it)
#(input, target) = next(it)
(input, target) = gather_batch_indices(pred, indices)
loss = self.loss(input, target)
loss.backward()
self.opt.step()


+ 117
- 3
tests/icosagon/test_batch.py View File

@@ -1,4 +1,8 @@
from icosagon.batch import PredictionsBatch
from icosagon.batch import PredictionsBatch, \
FlatPredictions, \
flatten_predictions, \
BatchIndices, \
gather_batch_indices
from icosagon.declayer import Predictions, \
RelationPredictions, \
RelationFamilyPredictions
@@ -6,6 +10,113 @@ from icosagon.trainprep import prepare_training, \
TrainValTest
from icosagon.data import Data
import torch
import pytest
def test_flat_predictions_01():
    """FlatPredictions exposes the values it was constructed with."""
    scores = torch.tensor([0, 1, 0, 1])
    labels = torch.tensor([1, 0, 1, 0])

    flat = FlatPredictions(scores, labels, 'train')

    assert torch.all(flat.predictions == torch.tensor([0, 1, 0, 1]))
    assert torch.all(flat.truth == torch.tensor([1, 0, 1, 0]))
    assert flat.part_type == 'train'
def test_flatten_predictions_01():
    """Flattening 'train' yields the concatenated scores with 1/0 targets."""
    def scored_split():
        # Only the first (train) slot carries values.
        scores = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
        return TrainValTest(scores, torch.zeros(0), torch.zeros(0))

    def empty_split():
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(scored_split(), scored_split(),
        empty_split(), empty_split())
    pred = Predictions([ RelationFamilyPredictions([ rel_pred ]) ])

    flat = flatten_predictions(pred, part_type='train')

    expected_scores = torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1],
        dtype=torch.float32)
    expected_labels = torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        dtype=torch.float32)
    assert torch.all(flat.predictions == expected_scores)
    assert torch.all(flat.truth == expected_labels)
    assert flat.part_type == 'train'
def test_flatten_predictions_02():
    """Flattening 'val' gives empty tensors when only train slots are filled."""
    def scored_split():
        # Only the first (train) slot carries values.
        scores = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
        return TrainValTest(scores, torch.zeros(0), torch.zeros(0))

    def empty_split():
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(scored_split(), scored_split(),
        empty_split(), empty_split())
    pred = Predictions([ RelationFamilyPredictions([ rel_pred ]) ])

    flat = flatten_predictions(pred, part_type='val')

    assert len(flat.predictions) == 0
    assert len(flat.truth) == 0
    assert flat.part_type == 'val'
def test_flatten_predictions_03():
    """Flattening 'test' gives empty tensors when only train slots are filled."""
    def scored_split():
        # Only the first (train) slot carries values.
        scores = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
        return TrainValTest(scores, torch.zeros(0), torch.zeros(0))

    def empty_split():
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(scored_split(), scored_split(),
        empty_split(), empty_split())
    pred = Predictions([ RelationFamilyPredictions([ rel_pred ]) ])

    flat = flatten_predictions(pred, part_type='test')

    assert len(flat.predictions) == 0
    assert len(flat.truth) == 0
    assert flat.part_type == 'test'
def test_flatten_predictions_04():
    """flatten_predictions rejects a non-Predictions input and a bad part_type."""
    def scored_split():
        # Only the first (train) slot carries values.
        scores = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
        return TrainValTest(scores, torch.zeros(0), torch.zeros(0))

    def empty_split():
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(scored_split(), scored_split(),
        empty_split(), empty_split())
    pred = Predictions([ RelationFamilyPredictions([ rel_pred ]) ])

    # A plain integer is not a Predictions instance.
    with pytest.raises(TypeError):
        flatten_predictions(1, part_type='test')

    # 'x' is not one of the accepted partition names.
    with pytest.raises(ValueError):
        flatten_predictions(pred, part_type='x')
def test_batch_indices_01():
    """BatchIndices exposes the values it was constructed with."""
    positions = torch.tensor([0, 1, 2, 3, 4])

    batch = BatchIndices(positions, 'train')

    assert torch.all(batch.indices == torch.tensor([0, 1, 2, 3, 4]))
    assert batch.part_type == 'train'
def test_gather_batch_indices_01():
    """gather_batch_indices returns the selected predictions and targets."""
    def scored_split():
        # Only the first (train) slot carries values.
        scores = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
        return TrainValTest(scores, torch.zeros(0), torch.zeros(0))

    def empty_split():
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(scored_split(), scored_split(),
        empty_split(), empty_split())
    pred = Predictions([ RelationFamilyPredictions([ rel_pred ]) ])
    flat = flatten_predictions(pred, part_type='train')

    # Every selected position holds a score of 1: three from the first
    # five entries (target 1) and three from the last five (target 0).
    batch = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train')
    (selected_scores, selected_labels) = gather_batch_indices(flat, batch)

    expected_scores = torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32)
    expected_labels = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32)
    assert torch.all(selected_scores == expected_scores)
    assert torch.all(selected_labels == expected_labels)
def test_predictions_batch_01():
@@ -38,10 +149,13 @@ def test_predictions_batch_01():
fam_pred = RelationFamilyPredictions([ rel_pred ])
pred = Predictions([ fam_pred ])
batch = PredictionsBatch(pred, part_type='train', batch_size=1)
pred_flat = flatten_predictions(pred, part_type='train')
batch = PredictionsBatch(prep_d, part_type='train', batch_size=1)
count = 0
lst = []
for (input, target) in batch:
for indices in batch:
(input, target) = gather_batch_indices(pred_flat, indices)
assert len(input) == 1
assert len(target) == 1
lst.append((input[0], target[0]))


Loading…
Cancel
Save