|
@@ -1,12 +1,16 @@ |
|
|
from .fastconv import FastConvLayer
|
|
|
from .fastconv import FastConvLayer
|
|
|
from .bulkdec import BulkDecodeLayer
|
|
|
from .bulkdec import BulkDecodeLayer
|
|
|
from .input import OneHotInputLayer
|
|
|
from .input import OneHotInputLayer
|
|
|
|
|
|
from .trainprep import PreparedData
|
|
|
import torch
|
|
|
import torch
|
|
|
import types
|
|
|
import types
|
|
|
|
|
|
from typing import List, \
|
|
|
|
|
|
Union, \
|
|
|
|
|
|
Callable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastModel(torch.nn.Module):
|
|
|
class FastModel(torch.nn.Module):
|
|
|
def __init(self, prep_d: PreparedData,
|
|
|
|
|
|
|
|
|
def __init__(self, prep_d: PreparedData,
|
|
|
layer_dimensions: List[int] = [32, 64],
|
|
|
layer_dimensions: List[int] = [32, 64],
|
|
|
keep_prob: float = 1.,
|
|
|
keep_prob: float = 1.,
|
|
|
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
@@ -20,6 +24,7 @@ class FastModel(torch.nn.Module): |
|
|
layer_activation, dec_activation)
|
|
|
layer_activation, dec_activation)
|
|
|
|
|
|
|
|
|
self.prep_d = prep_d
|
|
|
self.prep_d = prep_d
|
|
|
|
|
|
self.layer_dimensions = layer_dimensions
|
|
|
self.keep_prob = float(keep_prob)
|
|
|
self.keep_prob = float(keep_prob)
|
|
|
self.rel_activation = rel_activation
|
|
|
self.rel_activation = rel_activation
|
|
|
self.layer_activation = layer_activation
|
|
|
self.layer_activation = layer_activation
|
|
@@ -55,7 +60,7 @@ class FastModel(torch.nn.Module): |
|
|
def forward(self, _):
|
|
|
def forward(self, _):
|
|
|
return self.seq(None)
|
|
|
return self.seq(None)
|
|
|
|
|
|
|
|
|
def self._check_params(self, prep_d, layer_dimensions, rel_activation,
|
|
|
|
|
|
|
|
|
def _check_params(self, prep_d, layer_dimensions, rel_activation,
|
|
|
layer_activation, dec_activation):
|
|
|
layer_activation, dec_activation):
|
|
|
|
|
|
|
|
|
if not isinstance(prep_d, PreparedData):
|
|
|
if not isinstance(prep_d, PreparedData):
|
|
|