IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Stanislaw Adaszewski 4 лет назад
Родитель
Сommit
9366687239
3 измененных файлов: 99 добавлений и 52 удалений
  1. +6
    -52
      src/icosagon/model.py
  2. +69
    -0
      src/icosagon/trainloop.py
  3. +24
    -0
      tests/icosagon/test_trainloop.py

+ 6
- 52
src/icosagon/model.py Просмотреть файл

@@ -10,16 +10,16 @@ from .declayer import DecodeLayer
from .batch import PredictionsBatch
class Model(object):
class Model(torch.nn.Module):
def __init__(self, prep_d: PreparedData,
layer_dimensions: List[int] = [32, 64],
keep_prob: float = 1.,
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
lr: float = 0.001,
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
batch_size: int = 100) -> None:
**kwargs) -> None:
super().__init__(**kwargs)
if not isinstance(prep_d, PreparedData):
raise TypeError('prep_d must be an instance of PreparedData')
@@ -38,25 +38,14 @@ class Model(object):
if not isinstance(dec_activation, FunctionType):
raise TypeError('dec_activation must be a function')
lr = float(lr)
if not isinstance(loss, FunctionType):
raise TypeError('loss must be a function')
batch_size = int(batch_size)
self.prep_d = prep_d
self.layer_dimensions = layer_dimensions
self.keep_prob = keep_prob
self.rel_activation = rel_activation
self.layer_activation = layer_activation
self.dec_activation = dec_activation
self.lr = lr
self.loss = loss
self.batch_size = batch_size
self.seq = None
self.opt = None
self.build()
@@ -84,40 +73,5 @@ class Model(object):
seq = torch.nn.Sequential(*seq)
self.seq = seq
opt = torch.optim.Adam(seq.parameters(), lr=self.lr)
self.opt = opt
def run_epoch(self):
pred = self.seq(None)
batch = PredictionsBatch(pred, batch_size=self.batch_size)
n = len(list(iter(batch)))
loss_sum = 0
for i in range(n):
self.opt.zero_grad()
pred = self.seq(None)
batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
seed = torch.rand(1).item()
rng_state = torch.get_rng_state()
torch.manual_seed(seed)
it = iter(batch)
torch.set_rng_state(rng_state)
for k in range(i):
_ = next(it)
(input, target) = next(it)
loss = self.loss(input, target)
loss.backward()
self.opt.step()
loss_sum += loss.detach().cpu().item()
return loss_sum
def train(self, max_epochs):
best_loss = None
best_epoch = None
for i in range(max_epochs):
loss = self.run_epoch()
if best_loss is None or loss < best_loss:
best_loss = loss
best_epoch = i
return loss, best_loss, best_epoch
def forward(self, _):
return self.seq(None)

+ 69
- 0
src/icosagon/trainloop.py Просмотреть файл

@@ -0,0 +1,69 @@
from .model import Model
import torch
from .batch import PredictionsBatch
from typing import Callable
from types import FunctionType
class TrainLoop(object):
def __init__(self, model: Model, lr: float = 0.001,
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
torch.nn.functional.binary_cross_entropy_with_logits,
batch_size: int = 100) -> None:
if not isinstance(model, Model):
raise TypeError('model must be an instance of Model')
lr = float(lr)
if not isinstance(loss, FunctionType):
raise TypeError('loss must be a function')
batch_size = int(batch_size)
self.model = model
self.lr = lr
self.loss = loss
self.batch_size = batch_size
self.opt = None
self.build()
def build(self) -> None:
opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.opt = opt
def run_epoch(self):
pred = self.model(None)
batch = PredictionsBatch(pred, batch_size=self.batch_size)
n = len(list(iter(batch)))
loss_sum = 0
for i in range(n):
self.opt.zero_grad()
pred = self.seq(None)
batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
seed = torch.rand(1).item()
rng_state = torch.get_rng_state()
torch.manual_seed(seed)
it = iter(batch)
torch.set_rng_state(rng_state)
for k in range(i):
_ = next(it)
(input, target) = next(it)
loss = self.loss(input, target)
loss.backward()
self.opt.step()
loss_sum += loss.detach().cpu().item()
return loss_sum
def train(self, max_epochs):
best_loss = None
best_epoch = None
for i in range(max_epochs):
loss = self.run_epoch()
if best_loss is None or loss < best_loss:
best_loss = loss
best_epoch = i
return loss, best_loss, best_epoch

+ 24
- 0
tests/icosagon/test_trainloop.py Просмотреть файл

@@ -0,0 +1,24 @@
from icosagon.data import Data
from icosagon.trainprep import prepare_training, \
TrainValTest
from icosagon.model import Model
from icosagon.trainloop import TrainLoop
import torch
def test_train_loop_01():
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
m = Model(prep_d)
loop = TrainLoop(m)
assert loop.model == m
assert loop.lr == 0.001
assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
assert loop.batch_size == 100

Загрузка…
Отмена
Сохранить