From 9366687239d7b6b4fd296624045a7633dff4a691 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Thu, 18 Jun 2020 13:00:17 +0200
Subject: [PATCH] Add TrainLoop.

---
 src/icosagon/model.py            | 58 +++------------------------
 src/icosagon/trainloop.py        | 69 ++++++++++++++++++++++++++++++++
 tests/icosagon/test_trainloop.py | 24 +++++++++++
 3 files changed, 99 insertions(+), 52 deletions(-)
 create mode 100644 src/icosagon/trainloop.py
 create mode 100644 tests/icosagon/test_trainloop.py

diff --git a/src/icosagon/model.py b/src/icosagon/model.py
index 7261129..1c9e413 100644
--- a/src/icosagon/model.py
+++ b/src/icosagon/model.py
@@ -10,16 +10,16 @@ from .declayer import DecodeLayer
 from .batch import PredictionsBatch
 
 
-class Model(object):
+class Model(torch.nn.Module):
     def __init__(self, prep_d: PreparedData,
         layer_dimensions: List[int] = [32, 64],
         keep_prob: float = 1.,
         rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
         layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
         dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
-        lr: float = 0.001,
-        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
-        batch_size: int = 100) -> None:
+        **kwargs) -> None:
+
+        super().__init__(**kwargs)
 
         if not isinstance(prep_d, PreparedData):
             raise TypeError('prep_d must be an instance of PreparedData')
@@ -38,25 +38,14 @@ class Model(object):
         if not isinstance(dec_activation, FunctionType):
             raise TypeError('dec_activation must be a function')
 
-        lr = float(lr)
-
-        if not isinstance(loss, FunctionType):
-            raise TypeError('loss must be a function')
-
-        batch_size = int(batch_size)
-
         self.prep_d = prep_d
         self.layer_dimensions = layer_dimensions
         self.keep_prob = keep_prob
         self.rel_activation = rel_activation
         self.layer_activation = layer_activation
         self.dec_activation = dec_activation
-        self.lr = lr
-        self.loss = loss
-        self.batch_size = batch_size
 
         self.seq = None
-        self.opt = None
 
         self.build()
 
@@ -84,40 +73,5 @@ class Model(object):
         seq = torch.nn.Sequential(*seq)
         self.seq = seq
 
-        opt = torch.optim.Adam(seq.parameters(), lr=self.lr)
-        self.opt = opt
-
-
-    def run_epoch(self):
-        pred = self.seq(None)
-        batch = PredictionsBatch(pred, batch_size=self.batch_size)
-        n = len(list(iter(batch)))
-        loss_sum = 0
-        for i in range(n):
-            self.opt.zero_grad()
-            pred = self.seq(None)
-            batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
-            seed = torch.rand(1).item()
-            rng_state = torch.get_rng_state()
-            torch.manual_seed(seed)
-            it = iter(batch)
-            torch.set_rng_state(rng_state)
-            for k in range(i):
-                _ = next(it)
-            (input, target) = next(it)
-            loss = self.loss(input, target)
-            loss.backward()
-            self.opt.step()
-            loss_sum += loss.detach().cpu().item()
-        return loss_sum
-
-
-    def train(self, max_epochs):
-        best_loss = None
-        best_epoch = None
-        for i in range(max_epochs):
-            loss = self.run_epoch()
-            if best_loss is None or loss < best_loss:
-                best_loss = loss
-                best_epoch = i
-        return loss, best_loss, best_epoch
+    def forward(self, _):
+        return self.seq(None)
diff --git a/src/icosagon/trainloop.py b/src/icosagon/trainloop.py
new file mode 100644
index 0000000..051019e
--- /dev/null
+++ b/src/icosagon/trainloop.py
@@ -0,0 +1,69 @@
+from .model import Model
+import torch
+from .batch import PredictionsBatch
+from typing import Callable
+from types import FunctionType
+
+
+class TrainLoop(object):
+    def __init__(self, model: Model, lr: float = 0.001,
+        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
+            torch.nn.functional.binary_cross_entropy_with_logits,
+        batch_size: int = 100) -> None:
+
+        if not isinstance(model, Model):
+            raise TypeError('model must be an instance of Model')
+
+        lr = float(lr)
+
+        if not isinstance(loss, FunctionType):
+            raise TypeError('loss must be a function')
+
+        batch_size = int(batch_size)
+
+        self.model = model
+        self.lr = lr
+        self.loss = loss
+        self.batch_size = batch_size
+
+        self.opt = None
+
+        self.build()
+
+    def build(self) -> None:
+        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+        self.opt = opt
+
+    def run_epoch(self):
+        pred = self.model(None)
+        batch = PredictionsBatch(pred, batch_size=self.batch_size)
+        n = len(list(iter(batch)))
+        loss_sum = 0
+        for i in range(n):
+            self.opt.zero_grad()
+            pred = self.model(None)
+            batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
+            seed = torch.rand(1).item()
+            rng_state = torch.get_rng_state()
+            torch.manual_seed(seed)
+            it = iter(batch)
+            torch.set_rng_state(rng_state)
+            for k in range(i):
+                _ = next(it)
+            (input, target) = next(it)
+            loss = self.loss(input, target)
+            loss.backward()
+            self.opt.step()
+            loss_sum += loss.detach().cpu().item()
+        return loss_sum
+
+
+    def train(self, max_epochs):
+        best_loss = None
+        best_epoch = None
+        for i in range(max_epochs):
+            loss = self.run_epoch()
+            if best_loss is None or loss < best_loss:
+                best_loss = loss
+                best_epoch = i
+        return loss, best_loss, best_epoch
diff --git a/tests/icosagon/test_trainloop.py b/tests/icosagon/test_trainloop.py
new file mode 100644
index 0000000..2476c9c
--- /dev/null
+++ b/tests/icosagon/test_trainloop.py
@@ -0,0 +1,24 @@
+from icosagon.data import Data
+from icosagon.trainprep import prepare_training, \
+    TrainValTest
+from icosagon.model import Model
+from icosagon.trainloop import TrainLoop
+import torch
+
+
+def test_train_loop_01():
+    d = Data()
+    d.add_node_type('Dummy', 10)
+    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+    fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
+
+    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+    m = Model(prep_d)
+
+    loop = TrainLoop(m)
+
+    assert loop.model == m
+    assert loop.lr == 0.001
+    assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
+    assert loop.batch_size == 100
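
Usage sketch (not part of the patch itself): after this change, Model only builds the
encoder/decoder stack, while the optimizer, loss and batching live in TrainLoop. The
snippet below mirrors test_train_loop_01 and assumes the same dummy Data setup from
the test; the max_epochs value is arbitrary.

    from icosagon.data import Data
    from icosagon.trainprep import prepare_training, TrainValTest
    from icosagon.model import Model
    from icosagon.trainloop import TrainLoop
    import torch

    # Build a toy dataset with one node type and one relation type.
    d = Data()
    d.add_node_type('Dummy', 10)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())

    # Split edges into train/val/test and construct the model.
    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
    m = Model(prep_d)

    # Optimization settings that used to be Model arguments go here now.
    loop = TrainLoop(m, lr=0.001, batch_size=100)

    # train() returns the last epoch's loss, the best loss and the best epoch index.
    last_loss, best_loss, best_epoch = loop.train(max_epochs=10)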