from .fastconv import FastConvLayer
from .bulkdec import BulkDecodeLayer
from .input import OneHotInputLayer
from .trainprep import PreparedData
import torch
import types
from typing import List, \
    Union, \
    Callable, \
    Optional


class FastModel(torch.nn.Module):
    """Model chaining a one-hot input layer, ``FastConvLayer``s and a decoder.

    The layers are assembled by :meth:`build` into a ``torch.nn.Sequential``
    stored in ``self.seq``: one ``OneHotInputLayer``, one ``FastConvLayer``
    per entry of ``layer_dimensions``, and a final ``BulkDecodeLayer``.
    """

    def __init__(self, prep_d: PreparedData,
        layer_dimensions: Optional[List[int]] = None,
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:
        """Validate the configuration and build the layer stack.

        Args:
            prep_d: Prepared data; must be a ``PreparedData`` instance.
            layer_dimensions: Output dimension of each convolution layer, one
                ``FastConvLayer`` per entry. ``None`` (the default) means
                ``[32, 64]``. The list is copied, so later mutation by the
                caller does not affect the model's attribute.
            keep_prob: Passed (as a float) to every conv layer and the decode
                layer; presumably a dropout keep probability — confirm in
                ``FastConvLayer`` / ``BulkDecodeLayer``.
            rel_activation: Callable passed to each ``FastConvLayer`` as
                ``rel_activation``.
            layer_activation: Callable passed to each ``FastConvLayer`` as
                ``layer_activation``.
            dec_activation: Callable passed to ``BulkDecodeLayer`` as
                ``activation``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.

        Raises:
            TypeError: If any argument fails :meth:`_check_params`.
        """
        super().__init__(**kwargs)
        if layer_dimensions is None:
            # Build a fresh list per call; a list literal as the default
            # would be one object shared by every instance.
            layer_dimensions = [32, 64]
        self._check_params(prep_d, layer_dimensions, rel_activation,
            layer_activation, dec_activation)
        self.prep_d = prep_d
        # Defensive copy so the stored configuration is not aliased to the
        # caller's list.
        self.layer_dimensions = list(layer_dimensions)
        self.keep_prob = float(keep_prob)
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation
        self.seq = None
        self.build()

    def build(self):
        """Create the layers and store them as ``self.seq`` (``nn.Sequential``).

        Each layer's ``input_dim`` is the previous layer's ``output_dim``.
        Every conv layer gets ``output_dim = [dim] * len(prep_d.node_types)``,
        i.e. the same width repeated once per node type.
        """
        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        layers = [in_layer]
        for dim in self.layer_dimensions:
            conv_layer = FastConvLayer(input_dim=last_output_dim,
                output_dim=[dim] * len(self.prep_d.node_types),
                data=self.prep_d,
                keep_prob=self.keep_prob,
                rel_activation=self.rel_activation,
                layer_activation=self.layer_activation)
            last_output_dim = conv_layer.output_dim
            layers.append(conv_layer)
        dec_layer = BulkDecodeLayer(input_dim=last_output_dim,
            data=self.prep_d,
            keep_prob=self.keep_prob,
            activation=self.dec_activation)
        layers.append(dec_layer)
        self.seq = torch.nn.Sequential(*layers)

    def forward(self, _):
        """Run the layer stack; the argument is ignored.

        ``None`` is fed to ``self.seq``. Presumably ``OneHotInputLayer``
        derives its input from ``prep_d`` itself — confirm in that class.
        """
        return self.seq(None)

    def _check_params(self, prep_d, layer_dimensions, rel_activation,
        layer_activation, dec_activation):
        """Validate constructor arguments, raising ``TypeError`` on mismatch.

        Activations only need to be callable. This accepts plain functions,
        builtins such as ``torch.relu``, ``functools.partial`` objects and
        ``nn.Module`` instances alike.
        """
        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')
        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')
        for name, fn in (('rel_activation', rel_activation),
                         ('layer_activation', layer_activation),
                         ('dec_activation', dec_activation)):
            if not callable(fn):
                raise TypeError(f'{name} must be callable')