| @@ -0,0 +1,74 @@ | |||
| from .fastconv import FastConvLayer | |||
| from .bulkdec import BulkDecodeLayer | |||
| from .input import OneHotInputLayer | |||
| import torch | |||
| import types | |||
| class FastModel(torch.nn.Module): | |||
| def __init(self, prep_d: PreparedData, | |||
| layer_dimensions: List[int] = [32, 64], | |||
| keep_prob: float = 1., | |||
| rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||
| layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, | |||
| dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||
| **kwargs) -> None: | |||
| super().__init__(**kwargs) | |||
| self._check_params(prep_d, layer_dimensions, rel_activation, | |||
| layer_activation, dec_activation) | |||
| self.prep_d = prep_d | |||
| self.keep_prob = float(keep_prob) | |||
| self.rel_activation = rel_activation | |||
| self.layer_activation = layer_activation | |||
| self.dec_activation = dec_activation | |||
| self.seq = None | |||
| self.build() | |||
| def build(self): | |||
| in_layer = OneHotInputLayer(self.prep_d) | |||
| last_output_dim = in_layer.output_dim | |||
| seq = [ in_layer ] | |||
| for dim in self.layer_dimensions: | |||
| conv_layer = FastConvLayer(input_dim = last_output_dim, | |||
| output_dim = [dim] * len(self.prep_d.node_types), | |||
| data = self.prep_d, | |||
| keep_prob = self.keep_prob, | |||
| rel_activation = self.rel_activation, | |||
| layer_activation = self.layer_activation) | |||
| last_output_dim = conv_layer.output_dim | |||
| seq.append(conv_layer) | |||
| dec_layer = BulkDecodeLayer(input_dim = last_output_dim, | |||
| data = self.prep_d, | |||
| keep_prob = self.keep_prob, | |||
| activation = self.dec_activation) | |||
| seq.append(dec_layer) | |||
| seq = torch.nn.Sequential(*seq) | |||
| self.seq = seq | |||
| def forward(self, _): | |||
| return self.seq(None) | |||
| def self._check_params(self, prep_d, layer_dimensions, rel_activation, | |||
| layer_activation, dec_activation): | |||
| if not isinstance(prep_d, PreparedData): | |||
| raise TypeError('prep_d must be an instanced of PreparedData') | |||
| if not isinstance(layer_dimensions, list): | |||
| raise TypeError('layer_dimensions must be a list') | |||
| if not isinstance(rel_activation, types.FunctionType): | |||
| raise TypeError('rel_activation must be a function') | |||
| if not isinstance(layer_activation, types.FunctionType): | |||
| raise TypeError('layer_activation must be a function') | |||
| if not isinstance(dec_activation, types.FunctionType): | |||
| raise TypeError('dec_activation must be a function') | |||