from typing import Callable, List

import torch

from .batch import PredictionsBatch
from .convlayer import DecagonLayer
from .data import Data
from .declayer import DecodeLayer
from .input import OneHotInputLayer
from .trainprep import prepare_training, TrainValTest


class Model(object):
    """End-to-end model pipeline: a one-hot input layer, a stack of graph
    convolution (``DecagonLayer``) layers and a ``DecodeLayer``, trained
    with Adam on batched predictions.
    """

    def __init__(self, data: Data,
        layer_dimensions: List[int] = None,
        ratios: TrainValTest = None,
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100) -> None:
        """Validate and store the configuration, then build the network.

        Parameters
        ----------
        data : Data
            Raw graph data to train on.
        layer_dimensions : list of int, optional
            Per-convolution-layer output widths (default ``[32, 64]``).
        ratios : TrainValTest, optional
            Train/validation/test split (default ``TrainValTest(.8, .1, .1)``).
        keep_prob : float
            Dropout keep probability passed to every layer.
        rel_activation, layer_activation, dec_activation : callable
            Activations for relations, conv layers and the decoder.
        lr : float
            Adam learning rate.
        loss : callable
            Loss taking ``(input, target)`` tensors, returning a scalar tensor.
        batch_size : int
            Number of predictions per training batch.

        Raises
        ------
        TypeError
            If any argument has the wrong type.
        """

        # None-sentinels instead of mutable / def-time-constructed defaults,
        # so instances never share state through the signature.
        if layer_dimensions is None:
            layer_dimensions = [32, 64]
        if ratios is None:
            ratios = TrainValTest(.8, .1, .1)

        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        # was: isinstance(layer_sizes, list) — undefined name (NameError)
        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        if not isinstance(ratios, TrainValTest):
            raise TypeError('ratios must be an instance of TrainValTest')

        keep_prob = float(keep_prob)

        # callable() instead of isinstance(..., FunctionType): the latter
        # rejects builtins, functools.partial objects and nn.Modules.
        if not callable(rel_activation):
            raise TypeError('rel_activation must be callable')

        if not callable(layer_activation):
            raise TypeError('layer_activation must be callable')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be callable')

        lr = float(lr)

        if not callable(loss):
            raise TypeError('loss must be callable')

        batch_size = int(batch_size)

        self.data = data
        self.layer_dimensions = layer_dimensions
        self.ratios = ratios
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size

        self.build()

    def build(self) -> None:
        """Prepare the training split and assemble the layer stack + optimizer."""
        self.prep_d = prepare_training(self.data, self.ratios)

        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [in_layer]

        # One conv layer per requested width; each layer's output feeds the next.
        for dim in self.layer_dimensions:
            conv_layer = DecagonLayer(input_dim=last_output_dim,
                output_dim=[dim] * len(self.prep_d.node_types),
                data=self.prep_d,
                keep_prob=self.keep_prob,
                rel_activation=self.rel_activation,
                layer_activation=self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)

        dec_layer = DecodeLayer(input_dim=last_output_dim,
            data=self.prep_d,
            keep_prob=self.keep_prob,
            activation=self.dec_activation)
        seq.append(dec_layer)

        self.seq = torch.nn.Sequential(*seq)
        self.opt = torch.optim.Adam(self.seq.parameters(), lr=self.lr)

    def run_epoch(self) -> float:
        """Run one optimization pass over every prediction batch.

        Returns the sum of per-batch losses for the epoch.
        """
        # Dry-run forward pass only to count how many batches an epoch has.
        pred = self.seq(None)
        batch = PredictionsBatch(pred, self.batch_size)
        n = len(list(iter(batch)))

        loss_sum = 0.
        for i in range(n):  # was range(n - 1): the last batch was silently dropped
            self.opt.zero_grad()

            # Re-run the forward pass so the graph is fresh for this batch.
            pred = self.seq(None)
            batch = PredictionsBatch(pred, self.batch_size)

            # manual_seed requires an int; the original passed a float in [0, 1).
            seed = int(torch.rand(1).item() * 0x7fffffff)
            rng_state = torch.get_rng_state()
            torch.manual_seed(seed)
            it = iter(batch)
            torch.set_rng_state(rng_state)
            # NOTE(review): restoring the RNG state right after iter() only
            # isolates the batch shuffling if iter(batch) consumes the RNG
            # eagerly — confirm against PredictionsBatch.

            # Skip the batches already processed this epoch (O(n^2) overall,
            # kept as-is because batches must come from the fresh forward pass).
            for _ in range(i):
                next(it)
            (input, target) = next(it)

            batch_loss = self.loss(input, target)
            batch_loss.backward()
            self.opt.step()  # was self.opt.optimize(): torch optimizers expose step()
            loss_sum += batch_loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs: int):
        """Train for up to ``max_epochs`` epochs.

        Returns ``(last_loss, best_loss, best_epoch)``; all three are None
        when ``max_epochs`` is 0 (the original raised NameError in that case).
        """
        loss = None
        best_loss = None
        best_epoch = None
        for i in range(int(max_epochs)):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch