From f99f0bb9195d5a7fe54b820c20fe4664263f16d6 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 28 Jul 2020 11:39:34 +0200 Subject: [PATCH] Add FastModel. --- src/icosagon/fastmodel.py | 74 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 src/icosagon/fastmodel.py diff --git a/src/icosagon/fastmodel.py b/src/icosagon/fastmodel.py new file mode 100644 index 0000000..7c30906 --- /dev/null +++ b/src/icosagon/fastmodel.py @@ -0,0 +1,74 @@ +from .fastconv import FastConvLayer +from .bulkdec import BulkDecodeLayer +from .input import OneHotInputLayer +import torch +import types + + +class FastModel(torch.nn.Module): + def __init(self, prep_d: PreparedData, + layer_dimensions: List[int] = [32, 64], + keep_prob: float = 1., + rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, + layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, + dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, + **kwargs) -> None: + + super().__init__(**kwargs) + + self._check_params(prep_d, layer_dimensions, rel_activation, + layer_activation, dec_activation) + + self.prep_d = prep_d + self.keep_prob = float(keep_prob) + self.rel_activation = rel_activation + self.layer_activation = layer_activation + self.dec_activation = dec_activation + + self.seq = None + self.build() + + def build(self): + in_layer = OneHotInputLayer(self.prep_d) + last_output_dim = in_layer.output_dim + seq = [ in_layer ] + + for dim in self.layer_dimensions: + conv_layer = FastConvLayer(input_dim = last_output_dim, + output_dim = [dim] * len(self.prep_d.node_types), + data = self.prep_d, + keep_prob = self.keep_prob, + rel_activation = self.rel_activation, + layer_activation = self.layer_activation) + last_output_dim = conv_layer.output_dim + seq.append(conv_layer) + + dec_layer = BulkDecodeLayer(input_dim = last_output_dim, + data = self.prep_d, + keep_prob = self.keep_prob, + activation = self.dec_activation) + seq.append(dec_layer) + + seq = torch.nn.Sequential(*seq) + self.seq = seq + + def forward(self, _): + return self.seq(None) + + def self._check_params(self, prep_d, layer_dimensions, rel_activation, + layer_activation, dec_activation): + + if not isinstance(prep_d, PreparedData): + raise TypeError('prep_d must be an instanced of PreparedData') + + if not isinstance(layer_dimensions, list): + raise TypeError('layer_dimensions must be a list') + + if not isinstance(rel_activation, types.FunctionType): + raise TypeError('rel_activation must be a function') + + if not isinstance(layer_activation, types.FunctionType): + raise TypeError('layer_activation must be a function') + + if not isinstance(dec_activation, types.FunctionType): + raise TypeError('dec_activation must be a function')