from types import FunctionType
from typing import Callable, List

import torch

from .batch import PredictionsBatch
from .convlayer import DecagonLayer
from .data import Data
from .declayer import DecodeLayer
from .input import OneHotInputLayer
from .trainprep import prepare_training, TrainValTest
class Model(object):
    """Decagon-style link-prediction model.

    Wires a one-hot input layer, a stack of ``DecagonLayer`` graph
    convolutions and a ``DecodeLayer`` into one ``torch.nn.Sequential``,
    and trains it end-to-end with Adam on minibatched predictions.
    """

    def __init__(self, data: Data,
        # NOTE: list/TrainValTest defaults are evaluated once at class
        # definition and shared across calls; they are never mutated here.
        layer_dimensions: List[int] = [32, 64],
        ratios: TrainValTest = TrainValTest(.8, .1, .1),
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        lr: float = 0.001,
        # FIX: was "loss = Callable[...] = ..." — a SyntaxError; an
        # annotated parameter uses a colon, not '='.
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100) -> None:
        """Validate hyper-parameters, store them, and build the network.

        Parameters
        ----------
        data : Data
            Graph data the model is built from.
        layer_dimensions : List[int]
            Output embedding size of each convolutional layer.
        ratios : TrainValTest
            Train/validation/test split ratios.
        keep_prob : float
            Dropout keep probability passed to conv and decode layers.
        rel_activation, layer_activation, dec_activation : callable
            Activations for relation weights, conv layers, and decoder.
        lr : float
            Adam learning rate.
        loss : callable
            Loss taking ``(predictions, targets)`` tensors.
        batch_size : int
            Number of predictions per minibatch.

        Raises
        ------
        TypeError
            If an argument has the wrong type.
        """
        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')
        # FIX: was "isinstance(layer_sizes, list)" — layer_sizes is
        # undefined (NameError at runtime); the parameter is
        # layer_dimensions.
        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')
        if not isinstance(ratios, TrainValTest):
            raise TypeError('ratios must be an instance of TrainValTest')
        keep_prob = float(keep_prob)
        # FIX: the isinstance(..., FunctionType) checks rejected valid
        # callables (functools.partial, bound methods, nn.Module
        # instances); callable() accepts everything that can be called.
        if not callable(rel_activation):
            raise TypeError('rel_activation must be a function')
        if not callable(layer_activation):
            raise TypeError('layer_activation must be a function')
        if not callable(dec_activation):
            raise TypeError('dec_activation must be a function')
        lr = float(lr)
        if not callable(loss):
            raise TypeError('loss must be a function')
        batch_size = int(batch_size)

        self.data = data
        self.layer_dimensions = layer_dimensions
        self.ratios = ratios
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size

        self.build()

    def build(self):
        """Assemble the layer stack (self.seq) and the Adam optimizer
        (self.opt) from the stored hyper-parameters."""
        self.prep_d = prepare_training(self.data, self.ratios)
        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [ in_layer ]
        # One convolution per requested dimension; every node type gets
        # the same embedding size at a given depth.
        for dim in self.layer_dimensions:
            conv_layer = DecagonLayer(input_dim = last_output_dim,
                output_dim = [ dim ] * len(self.prep_d.node_types),
                data = self.prep_d,
                keep_prob = self.keep_prob,
                rel_activation = self.rel_activation,
                layer_activation = self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)
        dec_layer = DecodeLayer(input_dim = last_output_dim,
            data = self.prep_d,
            keep_prob = self.keep_prob,
            activation = self.dec_activation)
        seq.append(dec_layer)
        seq = torch.nn.Sequential(*seq)
        self.seq = seq
        opt = torch.optim.Adam(seq.parameters(), lr=self.lr)
        self.opt = opt

    def run_epoch(self):
        """Run one training epoch and return the summed minibatch loss."""
        pred = self.seq(None)
        batch = PredictionsBatch(pred, self.batch_size)
        # Count the minibatches by exhausting one throwaway iterator.
        n = len(list(iter(batch)))
        loss_sum = 0
        # NOTE(review): only n - 1 batches are trained on — presumably to
        # drop a final, smaller batch; confirm this is intentional.
        for i in range(n - 1):
            self.opt.zero_grad()
            pred = self.seq(None)
            batch = PredictionsBatch(pred, self.batch_size)
            # Seed the RNG before creating the iterator so each recreated
            # PredictionsBatch shuffles identically, allowing the i
            # already-trained batches to be skipped; the outer RNG state
            # is saved and restored around it.
            # NOTE(review): torch.manual_seed truncates its argument to
            # int, so this float seed in [0, 1) always becomes 0 — the
            # intent was probably one seed drawn once before the loop;
            # verify before changing.
            seed = torch.rand(1).item()
            rng_state = torch.get_rng_state()
            torch.manual_seed(seed)
            it = iter(batch)
            torch.set_rng_state(rng_state)
            for k in range(i):
                _ = next(it)
            (input, target) = next(it)
            loss = self.loss(input, target)
            loss.backward()
            # FIX: was self.opt.optimize() — torch optimizers have no
            # optimize() method (AttributeError); the parameter-update
            # method is step().
            self.opt.step()
            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        """Train for up to max_epochs epochs.

        Returns
        -------
        (last_loss, best_loss, best_epoch) — best_loss/best_epoch track
        the minimum epoch loss; all three are None-safe for max_epochs=0.
        """
        # FIX: loss was only bound inside the loop, so max_epochs == 0
        # raised NameError at the return statement.
        loss = None
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch