from .data import Data
from typing import List, \
    Callable, \
    Optional
from .trainprep import PreparedData
import torch
from .convlayer import DecagonLayer
from .input import OneHotInputLayer
from types import FunctionType
from .declayer import DecodeLayer
from .batch import PredictionsBatch


class Model(torch.nn.Module):
    """Stack of input, convolution and decoding layers built from PreparedData.

    The network is a ``torch.nn.Sequential`` of:

    * one ``OneHotInputLayer`` over ``prep_d``,
    * one ``DecagonLayer`` per entry of ``layer_dimensions``, each producing
      ``dim`` features for every node type in ``prep_d.node_types``,
    * one final ``DecodeLayer``.
    """

    def __init__(self, prep_d: PreparedData,
        layer_dimensions: Optional[List[int]] = None,
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:
        """Validate the configuration and build the layer stack.

        Args:
            prep_d: Prepared dataset; must be a ``PreparedData`` instance.
            layer_dimensions: Output width of each ``DecagonLayer``, in order.
                Defaults to ``[32, 64]`` when omitted or ``None``. A copy of
                the given list is stored, so later mutation by the caller
                does not affect this model.
            keep_prob: Coerced to ``float`` and forwarded to every
                ``DecagonLayer`` and to the ``DecodeLayer`` (presumably a
                dropout keep-probability -- confirm in those layers).
            rel_activation: Callable forwarded to each ``DecagonLayer``.
            layer_activation: Callable forwarded to each ``DecagonLayer``.
            dec_activation: Callable forwarded to the ``DecodeLayer``.
            **kwargs: Passed through to ``torch.nn.Module.__init__``.

        Raises:
            TypeError: If ``prep_d`` is not ``PreparedData``, if
                ``layer_dimensions`` is not a list, or if any activation
                is not callable.
            ValueError / TypeError: From ``float(keep_prob)`` when it cannot
                be converted.
        """
        super().__init__(**kwargs)

        # A None sentinel avoids sharing one mutable default list between
        # every Model constructed without explicit dimensions.
        if layer_dimensions is None:
            layer_dimensions = [32, 64]

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        keep_prob = float(keep_prob)

        # callable() rather than isinstance(..., FunctionType): the latter
        # rejects builtins (e.g. torch.tanh), functools.partial objects and
        # callable nn.Module instances, contradicting the Callable hints.
        if not callable(rel_activation):
            raise TypeError('rel_activation must be a function')

        if not callable(layer_activation):
            raise TypeError('layer_activation must be a function')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be a function')

        self.prep_d = prep_d
        self.layer_dimensions = list(layer_dimensions)
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation

        self.seq = None
        self.build()

    def build(self):
        """Assemble the layer stack and store it in ``self.seq``."""
        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [ in_layer ]

        for dim in self.layer_dimensions:
            # Every node type gets the same output width for this layer.
            conv_layer = DecagonLayer(input_dim = last_output_dim,
                output_dim = [ dim ] * len(self.prep_d.node_types),
                data = self.prep_d,
                keep_prob = self.keep_prob,
                rel_activation = self.rel_activation,
                layer_activation = self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)

        dec_layer = DecodeLayer(input_dim = last_output_dim,
            data = self.prep_d,
            keep_prob = self.keep_prob,
            activation = self.dec_activation)
        seq.append(dec_layer)

        seq = torch.nn.Sequential(*seq)
        self.seq = seq

    def forward(self, _):
        """Run the stack; the argument is ignored.

        The sequence is always fed ``None``. Presumably the input layer
        derives its features from ``prep_d`` itself -- confirm in
        ``OneHotInputLayer``.
        """
        return self.seq(None)