from typing import List, Callable

import torch

from .batch import PredictionsBatch
from .convlayer import DecagonLayer
from .declayer import DecodeLayer
from .input import OneHotInputLayer
from .trainprep import PreparedData


class Model(object):
    def __init__(self, prep_d: PreparedData,
        layer_dimensions: List[int] = [32, 64],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100) -> None:

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        keep_prob = float(keep_prob)

        # callable() accepts plain functions as well as builtins, partials
        # and torch.nn modules, unlike an isinstance(..., FunctionType) check.
        if not callable(rel_activation):
            raise TypeError('rel_activation must be callable')

        if not callable(layer_activation):
            raise TypeError('layer_activation must be callable')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be callable')

        lr = float(lr)

        if not callable(loss):
            raise TypeError('loss must be callable')

        batch_size = int(batch_size)

        self.prep_d = prep_d
        self.layer_dimensions = layer_dimensions
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size

        self.seq = None
        self.opt = None

        self.build()

    def build(self) -> None:
        # Stack: one-hot input -> graph convolution layers -> decoder.
        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [in_layer]

        for dim in self.layer_dimensions:
            conv_layer = DecagonLayer(input_dim=last_output_dim,
                output_dim=[dim] * len(self.prep_d.node_types),
                data=self.prep_d,
                keep_prob=self.keep_prob,
                rel_activation=self.rel_activation,
                layer_activation=self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)

        dec_layer = DecodeLayer(input_dim=last_output_dim,
            data=self.prep_d,
            keep_prob=self.keep_prob,
            activation=self.dec_activation)
        seq.append(dec_layer)

        seq = torch.nn.Sequential(*seq)
        self.seq = seq

        opt = torch.optim.Adam(seq.parameters(), lr=self.lr)
        self.opt = opt

    def run_epoch(self) -> float:
        # A full forward pass yields predictions for every edge at once;
        # use it to count how many mini-batches make up one epoch.
        pred = self.seq(None)
        batch = PredictionsBatch(pred, batch_size=self.batch_size)
        n = len(list(iter(batch)))

        # Fix one integer seed per epoch so that every rebuilt iterator
        # below shuffles the predictions into the same order, making the
        # i-th mini-batch well-defined across iterations. (A float from
        # torch.rand() would be truncated to 0 by torch.manual_seed().)
        seed = int(torch.randint(0, 2 ** 31 - 1, (1,)).item())

        loss_sum = 0
        for i in range(n):
            self.opt.zero_grad()

            # Predictions change after every optimizer step, so recompute
            # them and rebuild the shuffled batch iterator each time.
            pred = self.seq(None)
            batch = PredictionsBatch(pred, batch_size=self.batch_size,
                shuffle=True)

            # Seed the global RNG only for the duration of iter(batch),
            # then restore it so the rest of training is unaffected.
            rng_state = torch.get_rng_state()
            torch.manual_seed(seed)
            it = iter(batch)
            torch.set_rng_state(rng_state)

            # Fast-forward to the i-th mini-batch of this epoch.
            for _ in range(i):
                _ = next(it)
            (batch_pred, batch_target) = next(it)

            loss = self.loss(batch_pred, batch_target)
            loss.backward()
            self.opt.step()

            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs: int):
        loss = None
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch
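
# Usage sketch (illustrative only): `prep_data` stands for whatever
# PreparedData instance the preparation step in .trainprep produces; the
# exact entry point that builds it is assumed here, not shown.
#
#     model = Model(prep_data, layer_dimensions=[32, 64],
#         lr=0.001, batch_size=100)
#     last_loss, best_loss, best_epoch = model.train(max_epochs=50)
#     print('Best loss:', best_loss, 'at epoch:', best_epoch)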