@@ -10,16 +10,16 @@ from .declayer import DecodeLayer | |||
from .batch import PredictionsBatch | |||
class Model(object): | |||
class Model(torch.nn.Module): | |||
def __init__(self, prep_d: PreparedData, | |||
layer_dimensions: List[int] = [32, 64], | |||
keep_prob: float = 1., | |||
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, | |||
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||
lr: float = 0.001, | |||
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits, | |||
batch_size: int = 100) -> None: | |||
**kwargs) -> None: | |||
super().__init__(**kwargs) | |||
if not isinstance(prep_d, PreparedData): | |||
raise TypeError('prep_d must be an instance of PreparedData') | |||
@@ -38,25 +38,14 @@ class Model(object): | |||
if not isinstance(dec_activation, FunctionType): | |||
raise TypeError('dec_activation must be a function') | |||
lr = float(lr) | |||
if not isinstance(loss, FunctionType): | |||
raise TypeError('loss must be a function') | |||
batch_size = int(batch_size) | |||
self.prep_d = prep_d | |||
self.layer_dimensions = layer_dimensions | |||
self.keep_prob = keep_prob | |||
self.rel_activation = rel_activation | |||
self.layer_activation = layer_activation | |||
self.dec_activation = dec_activation | |||
self.lr = lr | |||
self.loss = loss | |||
self.batch_size = batch_size | |||
self.seq = None | |||
self.opt = None | |||
self.build() | |||
@@ -84,40 +73,5 @@ class Model(object): | |||
seq = torch.nn.Sequential(*seq) | |||
self.seq = seq | |||
opt = torch.optim.Adam(seq.parameters(), lr=self.lr) | |||
self.opt = opt | |||
def run_epoch(self): | |||
pred = self.seq(None) | |||
batch = PredictionsBatch(pred, batch_size=self.batch_size) | |||
n = len(list(iter(batch))) | |||
loss_sum = 0 | |||
for i in range(n): | |||
self.opt.zero_grad() | |||
pred = self.seq(None) | |||
batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True) | |||
seed = torch.rand(1).item() | |||
rng_state = torch.get_rng_state() | |||
torch.manual_seed(seed) | |||
it = iter(batch) | |||
torch.set_rng_state(rng_state) | |||
for k in range(i): | |||
_ = next(it) | |||
(input, target) = next(it) | |||
loss = self.loss(input, target) | |||
loss.backward() | |||
self.opt.step() | |||
loss_sum += loss.detach().cpu().item() | |||
return loss_sum | |||
def train(self, max_epochs): | |||
best_loss = None | |||
best_epoch = None | |||
for i in range(max_epochs): | |||
loss = self.run_epoch() | |||
if best_loss is None or loss < best_loss: | |||
best_loss = loss | |||
best_epoch = i | |||
return loss, best_loss, best_epoch | |||
def forward(self, _): | |||
return self.seq(None) |
@@ -0,0 +1,69 @@ | |||
from .model import Model | |||
import torch | |||
from .batch import PredictionsBatch | |||
from typing import Callable | |||
from types import FunctionType | |||
class TrainLoop(object):
    """Owns the optimizer and the training procedure for a Model.

    Separates the training concern from the Model itself: the model only
    defines the forward computation; TrainLoop holds the learning rate,
    loss function, batch size and Adam optimizer, and drives the epochs.

    Parameters:
        model: the Model whose parameters are optimized (must be a Model).
        lr: learning rate for Adam (coerced to float).
        loss: callable mapping (input, target) tensors to a scalar loss.
        batch_size: number of predictions per batch (coerced to int).

    Raises:
        TypeError: if model is not a Model or loss is not a plain function.
    """

    def __init__(self, model: Model, lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100) -> None:

        if not isinstance(model, Model):
            raise TypeError('model must be an instance of Model')

        lr = float(lr)

        if not isinstance(loss, FunctionType):
            raise TypeError('loss must be a function')

        batch_size = int(batch_size)

        self.model = model
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size

        self.opt = None

        self.build()

    def build(self) -> None:
        """Create the Adam optimizer over the model's parameters."""
        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.opt = opt

    def run_epoch(self):
        """Run one epoch over all prediction batches; return the summed loss.

        The forward pass is recomputed for every batch so that each
        optimizer step sees predictions from the current parameters.
        """
        pred = self.model(None)
        batch = PredictionsBatch(pred, batch_size=self.batch_size)
        # First pass only counts how many batches one epoch contains.
        n = len(list(iter(batch)))
        loss_sum = 0
        for i in range(n):
            self.opt.zero_grad()
            # BUG FIX: was `self.seq(None)` — TrainLoop has no `seq`
            # attribute (that lives on Model); use the model's forward.
            pred = self.model(None)
            batch = PredictionsBatch(pred, batch_size=self.batch_size,
                shuffle=True)
            seed = torch.rand(1).item()
            rng_state = torch.get_rng_state()
            torch.manual_seed(seed)
            it = iter(batch)
            torch.set_rng_state(rng_state)
            # Skip the i batches already consumed in earlier iterations,
            # then train on batch i of this freshly shuffled pass.
            # NOTE(review): the seed is redrawn every iteration, so each
            # pass is shuffled differently before skipping — presumably
            # intentional (stochastic batch selection), but worth
            # confirming against PredictionsBatch's shuffle semantics.
            for k in range(i):
                _ = next(it)
            (input, target) = next(it)
            loss = self.loss(input, target)
            loss.backward()
            self.opt.step()
            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        """Train for up to max_epochs epochs.

        Returns:
            (last_loss, best_loss, best_epoch) — last_loss and best_loss
            are None and best_epoch is None when max_epochs <= 0.
        """
        loss = None  # guard: avoid NameError when max_epochs <= 0
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or \
                loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch
@@ -0,0 +1,24 @@ | |||
from icosagon.data import Data | |||
from icosagon.trainprep import prepare_training, \ | |||
TrainValTest | |||
from icosagon.model import Model | |||
from icosagon.trainloop import TrainLoop | |||
import torch | |||
def test_train_loop_01():
    """TrainLoop stores the model and exposes the documented defaults."""
    # Build a minimal dataset: one node type, one self-relation family.
    data = Data()
    data.add_node_type('Dummy', 10)
    family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
    family.add_relation_type('Dummy Rel', torch.rand(10, 10).round())

    # Standard 80/10/10 split, default model.
    prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
    model = Model(prep_d)

    loop = TrainLoop(model)

    # Constructor must keep the model and apply its default hyperparameters.
    assert loop.model == model
    assert loop.lr == 0.001
    assert loop.loss == \
        torch.nn.functional.binary_cross_entropy_with_logits
    assert loop.batch_size == 100