| @@ -0,0 +1,74 @@ | |||||
| from .fastconv import FastConvLayer | |||||
| from .bulkdec import BulkDecodeLayer | |||||
| from .input import OneHotInputLayer | |||||
| import torch | |||||
| import types | |||||
| class FastModel(torch.nn.Module): | |||||
| def __init(self, prep_d: PreparedData, | |||||
| layer_dimensions: List[int] = [32, 64], | |||||
| keep_prob: float = 1., | |||||
| rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||||
| layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, | |||||
| dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||||
| **kwargs) -> None: | |||||
| super().__init__(**kwargs) | |||||
| self._check_params(prep_d, layer_dimensions, rel_activation, | |||||
| layer_activation, dec_activation) | |||||
| self.prep_d = prep_d | |||||
| self.keep_prob = float(keep_prob) | |||||
| self.rel_activation = rel_activation | |||||
| self.layer_activation = layer_activation | |||||
| self.dec_activation = dec_activation | |||||
| self.seq = None | |||||
| self.build() | |||||
| def build(self): | |||||
| in_layer = OneHotInputLayer(self.prep_d) | |||||
| last_output_dim = in_layer.output_dim | |||||
| seq = [ in_layer ] | |||||
| for dim in self.layer_dimensions: | |||||
| conv_layer = FastConvLayer(input_dim = last_output_dim, | |||||
| output_dim = [dim] * len(self.prep_d.node_types), | |||||
| data = self.prep_d, | |||||
| keep_prob = self.keep_prob, | |||||
| rel_activation = self.rel_activation, | |||||
| layer_activation = self.layer_activation) | |||||
| last_output_dim = conv_layer.output_dim | |||||
| seq.append(conv_layer) | |||||
| dec_layer = BulkDecodeLayer(input_dim = last_output_dim, | |||||
| data = self.prep_d, | |||||
| keep_prob = self.keep_prob, | |||||
| activation = self.dec_activation) | |||||
| seq.append(dec_layer) | |||||
| seq = torch.nn.Sequential(*seq) | |||||
| self.seq = seq | |||||
| def forward(self, _): | |||||
| return self.seq(None) | |||||
| def self._check_params(self, prep_d, layer_dimensions, rel_activation, | |||||
| layer_activation, dec_activation): | |||||
| if not isinstance(prep_d, PreparedData): | |||||
| raise TypeError('prep_d must be an instanced of PreparedData') | |||||
| if not isinstance(layer_dimensions, list): | |||||
| raise TypeError('layer_dimensions must be a list') | |||||
| if not isinstance(rel_activation, types.FunctionType): | |||||
| raise TypeError('rel_activation must be a function') | |||||
| if not isinstance(layer_activation, types.FunctionType): | |||||
| raise TypeError('layer_activation must be a function') | |||||
| if not isinstance(dec_activation, types.FunctionType): | |||||
| raise TypeError('dec_activation must be a function') | |||||