"""End-to-end model: a one-hot input layer, a stack of DecagonLayer
convolutions, and a DecodeLayer, trained with Adam on minibatched
predictions."""

from typing import Callable, List

import torch

from .batch import PredictionsBatch
from .convlayer import DecagonLayer
from .data import Data
from .declayer import DecodeLayer
from .input import OneHotInputLayer
from .trainprep import prepare_training, TrainValTest

class Model(object):
    def __init__(self, data: Data,
            layer_dimensions: List[int] = [32, 64],
            ratios: TrainValTest = TrainValTest(.8, .1, .1),
            keep_prob: float = 1.,
            rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
            layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
            dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
            lr: float = 0.001,
            loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
            batch_size: int = 100) -> None:
        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')
        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')
        if not isinstance(ratios, TrainValTest):
            raise TypeError('ratios must be an instance of TrainValTest')
        keep_prob = float(keep_prob)
        if not callable(rel_activation):
            raise TypeError('rel_activation must be callable')
        if not callable(layer_activation):
            raise TypeError('layer_activation must be callable')
        if not callable(dec_activation):
            raise TypeError('dec_activation must be callable')
        lr = float(lr)
        if not callable(loss):
            raise TypeError('loss must be callable')
        batch_size = int(batch_size)
        self.data = data
        self.layer_dimensions = layer_dimensions
        self.ratios = ratios
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size

        self.build()

    def build(self):
        # Split the data into train/val/test sets and assemble the network:
        # a one-hot input layer, one DecagonLayer per requested dimension,
        # and a DecodeLayer producing the predictions.
        self.prep_d = prepare_training(self.data, self.ratios)

        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [ in_layer ]

        for dim in self.layer_dimensions:
            conv_layer = DecagonLayer(input_dim=last_output_dim,
                output_dim=[ dim ] * len(self.prep_d.node_types),
                data=self.prep_d,
                keep_prob=self.keep_prob,
                rel_activation=self.rel_activation,
                layer_activation=self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)

        dec_layer = DecodeLayer(input_dim=last_output_dim,
            data=self.prep_d,
            keep_prob=self.keep_prob,
            activation=self.dec_activation)
        seq.append(dec_layer)

        self.seq = torch.nn.Sequential(*seq)
        self.opt = torch.optim.Adam(self.seq.parameters(), lr=self.lr)

    def run_epoch(self):
        # One forward pass just to count the minibatches.
        pred = self.seq(None)
        batch = PredictionsBatch(pred, self.batch_size)
        n = len(list(iter(batch)))

        # A single integer seed is used for every minibatch below, so that
        # each freshly constructed PredictionsBatch shuffles identically and
        # skipping ahead selects the i-th minibatch of one consistent order.
        # Note that torch.manual_seed() truncates its argument to an int, so
        # the seed is drawn with randint() rather than rand().
        seed = torch.randint(0, 2 ** 31, (1,)).item()

        loss_sum = 0
        for i in range(n - 1): # the last (possibly partial) batch is skipped
            self.opt.zero_grad()

            # backward() frees the computation graph, so the full forward
            # pass has to be repeated for every minibatch.
            pred = self.seq(None)
            batch = PredictionsBatch(pred, self.batch_size)

            # Seed the global RNG while the iterator is constructed (the
            # shuffle inside PredictionsBatch presumably draws from it),
            # then restore the saved state so the rest of training is
            # unaffected.
            rng_state = torch.get_rng_state()
            torch.manual_seed(seed)
            it = iter(batch)
            torch.set_rng_state(rng_state)

            # Skip ahead to the i-th minibatch.
            for k in range(i):
                _ = next(it)
            (input, target) = next(it)

            loss = self.loss(input, target)
            loss.backward()
            self.opt.step()
            loss_sum += loss.detach().cpu().item()

        return loss_sum

    def train(self, max_epochs):
        loss = None
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch
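
if __name__ == '__main__':
    # Minimal usage sketch, not part of the module's API. It assumes Data
    # can be constructed with no arguments and populated afterwards; the
    # population calls are omitted because they depend on the Data API.
    d = Data()
    # ... add node types and relation types to `d` here ...

    m = Model(d,
        layer_dimensions=[32, 64],
        lr=0.001,
        batch_size=100)

    last_loss, best_loss, best_epoch = m.train(max_epochs=100)
    print('Best loss: %f (at epoch %d)' % (best_loss, best_epoch))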