|
- from .data import Data
- from typing import List, \
- Callable
- from .trainprep import prepare_training, \
- TrainValTest
- import torch
- from .convlayer import DecagonLayer
- from .input import OneHotInputLayer
- from types import FunctionType
- from .declayer import DecodeLayer
- from .batch import PredictionsBatch
-
-
class Model(object):
    """End-to-end Decagon-style graph-convolutional link-prediction model.

    Wires a one-hot input layer, a stack of ``DecagonLayer`` graph
    convolutions (one per entry in ``layer_dimensions``) and a final
    ``DecodeLayer`` into a ``torch.nn.Sequential``, then trains it with
    Adam on shuffled mini-batches of edge predictions.
    """

    def __init__(self, data: Data,
        layer_dimensions: List[int] = None,
        ratios: TrainValTest = None,
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100) -> None:
        """Validate hyperparameters, store them and build the network.

        Parameters
        ----------
        data : Data
            The raw graph data to train on.
        layer_dimensions : List[int], optional
            Output dimensionality of each convolution layer
            (default ``[32, 64]``).
        ratios : TrainValTest, optional
            Train/validation/test split ratios (default ``.8/.1/.1``).
        keep_prob : float
            1 - dropout probability, forwarded to every layer.
        rel_activation, layer_activation, dec_activation : callable
            Per-relation, per-layer and decoder activations.
        lr : float
            Adam learning rate.
        loss : callable
            Loss on (predictions, targets); expected to accept logits.
        batch_size : int
            Number of edge predictions per optimization step.

        Raises
        ------
        TypeError
            If any argument has the wrong type.
        """

        # BUG FIX: the original used mutable/shared default arguments
        # ([32, 64] and TrainValTest(...) evaluated once at def time);
        # use None sentinels instead.
        if layer_dimensions is None:
            layer_dimensions = [32, 64]
        if ratios is None:
            ratios = TrainValTest(.8, .1, .1)

        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        if not isinstance(ratios, TrainValTest):
            raise TypeError('ratios must be an instance of TrainValTest')

        keep_prob = float(keep_prob)

        # GENERALIZATION: callable() instead of isinstance(..., FunctionType).
        # The FunctionType check rejected perfectly usable callables such as
        # functools.partial objects, bound methods and C-implemented functions.
        if not callable(rel_activation):
            raise TypeError('rel_activation must be a function')

        if not callable(layer_activation):
            raise TypeError('layer_activation must be a function')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be a function')

        lr = float(lr)

        if not callable(loss):
            raise TypeError('loss must be a function')

        batch_size = int(batch_size)

        self.data = data
        self.layer_dimensions = layer_dimensions
        self.ratios = ratios
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size

        # Populated by build():
        self.prep_d = None   # prepared (split) training data
        self.seq = None      # torch.nn.Sequential of all layers
        self.opt = None      # Adam optimizer over self.seq

        self.build()

    def build(self):
        """Split the data and assemble input -> conv* -> decode into self.seq."""
        self.prep_d = prepare_training(self.data, self.ratios)

        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        layers = [ in_layer ]

        # One DecagonLayer per requested dimension; every node type gets
        # the same output dimensionality at a given depth.
        for dim in self.layer_dimensions:
            conv_layer = DecagonLayer(input_dim = last_output_dim,
                output_dim = [ dim ] * len(self.prep_d.node_types),
                data = self.prep_d,
                keep_prob = self.keep_prob,
                rel_activation = self.rel_activation,
                layer_activation = self.layer_activation)
            last_output_dim = conv_layer.output_dim
            layers.append(conv_layer)

        dec_layer = DecodeLayer(input_dim = last_output_dim,
            data = self.prep_d,
            keep_prob = self.keep_prob,
            activation = self.dec_activation)
        layers.append(dec_layer)

        self.seq = torch.nn.Sequential(*layers)
        self.opt = torch.optim.Adam(self.seq.parameters(), lr=self.lr)

    def run_epoch(self):
        """Run one training epoch and return the sum of the batch losses."""
        pred = self.seq(None)
        batch = PredictionsBatch(pred, self.batch_size)
        # Count the batches without keeping them materialized.
        n = sum(1 for _ in batch)
        loss_sum = 0
        # NOTE(review): range(n - 1) skips the final batch — presumably to
        # drop a ragged last batch; confirm this is intentional.
        for i in range(n - 1):
            self.opt.zero_grad()
            # The whole graph is re-encoded every step (predictions change
            # after each parameter update), so earlier batches cannot be
            # reused — each step re-shuffles and fast-forwards to batch i.
            pred = self.seq(None)
            batch = PredictionsBatch(pred, self.batch_size, shuffle=True)
            # Save/seed/restore the global RNG so the shuffle order is
            # controlled without disturbing other randomness (e.g. dropout).
            # NOTE(review): the state is restored right after iter(batch) —
            # if PredictionsBatch shuffles lazily during consumption this
            # will not pin the order; verify against its implementation.
            seed = torch.rand(1).item()
            rng_state = torch.get_rng_state()
            torch.manual_seed(seed)
            it = iter(batch)
            torch.set_rng_state(rng_state)
            for _ in range(i):
                next(it)
            # Renamed from `input` to avoid shadowing the builtin.
            (batch_pred, batch_target) = next(it)
            loss = self.loss(batch_pred, batch_target)
            loss.backward()
            # BUG FIX: torch optimizers expose .step(), not .optimize();
            # the original raised AttributeError on the first batch.
            self.opt.step()
            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        """Train for up to max_epochs epochs.

        Returns
        -------
        (last_loss, best_loss, best_epoch)
            Loss of the final epoch, the lowest epoch loss seen, and the
            index of the epoch that achieved it (all None if max_epochs == 0).
        """
        # BUG FIX: initialize loss so max_epochs == 0 does not raise
        # UnboundLocalError on the return statement.
        loss = None
        best_loss = None
        best_epoch = None
        for epoch in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = epoch
        return loss, best_loss, best_epoch
|