|
- from .fastconv import FastConvLayer
- from .bulkdec import BulkDecodeLayer
- from .input import OneHotInputLayer
- from .trainprep import PreparedData
- import torch
- import types
- from typing import List, \
- Union, \
- Callable
-
-
class FastModel(torch.nn.Module):
    """Sequential model: input layer, stacked conv layers, then a decoder.

    The network is assembled in :meth:`build` as a ``torch.nn.Sequential``
    consisting of one ``OneHotInputLayer``, one ``FastConvLayer`` per entry
    of ``layer_dimensions``, and a final ``BulkDecodeLayer``. All layers are
    constructed from the same ``PreparedData`` instance.
    """

    def __init__(self, prep_d: PreparedData,
                 layer_dimensions: Union[List[int], None] = None,
                 keep_prob: float = 1.,
                 rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
                 layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
                 **kwargs) -> None:
        """Validate the arguments, store them and build the layer stack.

        Args:
            prep_d: Prepared data object handed to every layer.
            layer_dimensions: One entry per convolution layer; each entry is
                used as the output dimension for every node type. ``None``
                (the default) means ``[32, 64]``.
            keep_prob: Forwarded (as ``float``) to every conv layer and to
                the decode layer. Presumably a dropout keep probability --
                confirm against the layer implementations.
            rel_activation: Forwarded to every ``FastConvLayer``.
            layer_activation: Forwarded to every ``FastConvLayer``.
            dec_activation: Forwarded to the ``BulkDecodeLayer``.
            **kwargs: Passed through to ``torch.nn.Module.__init__``.

        Raises:
            TypeError: If ``prep_d`` is not a ``PreparedData``, if
                ``layer_dimensions`` is not a list, or if any activation is
                not callable.
        """
        super().__init__(**kwargs)

        # A None sentinel gives every instance its own default list; a list
        # literal in the signature would be shared (and mutable) across all
        # instances created with the default.
        if layer_dimensions is None:
            layer_dimensions = [32, 64]

        self._check_params(prep_d, layer_dimensions, rel_activation,
                           layer_activation, dec_activation)

        self.prep_d = prep_d
        self.layer_dimensions = layer_dimensions
        self.keep_prob = float(keep_prob)
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation

        self.seq = None
        self.build()

    def build(self):
        """Create the layers and store them in ``self.seq``.

        Each layer's ``output_dim`` is fed to the next layer as its
        ``input_dim``.
        """
        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        layers = [in_layer]

        # Every conv layer uses the same width for all node types.
        num_node_types = len(self.prep_d.node_types)
        for dim in self.layer_dimensions:
            conv_layer = FastConvLayer(input_dim=last_output_dim,
                                       output_dim=[dim] * num_node_types,
                                       data=self.prep_d,
                                       keep_prob=self.keep_prob,
                                       rel_activation=self.rel_activation,
                                       layer_activation=self.layer_activation)
            last_output_dim = conv_layer.output_dim
            layers.append(conv_layer)

        dec_layer = BulkDecodeLayer(input_dim=last_output_dim,
                                    data=self.prep_d,
                                    keep_prob=self.keep_prob,
                                    activation=self.dec_activation)
        layers.append(dec_layer)

        self.seq = torch.nn.Sequential(*layers)

    def forward(self, _):
        """Run the layer stack.

        The argument is ignored; the sequential stack is always invoked
        with ``None`` as its input.
        """
        return self.seq(None)

    def _check_params(self, prep_d, layer_dimensions, rel_activation,
                      layer_activation, dec_activation):
        """Type-check the constructor arguments.

        Activations are checked with ``callable()`` rather than
        ``types.FunctionType`` so that built-in functions (for example
        ``torch.sigmoid``), ``functools.partial`` objects and
        ``torch.nn.Module`` instances are accepted, in line with the
        ``Callable`` annotations on ``__init__``.

        Raises:
            TypeError: On the first argument that fails its check.
        """
        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instanced of PreparedData')

        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        if not callable(rel_activation):
            raise TypeError('rel_activation must be a function')

        if not callable(layer_activation):
            raise TypeError('layer_activation must be a function')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be a function')
|