|
|
@@ -0,0 +1,74 @@ |
|
|
|
from .fastconv import FastConvLayer
|
|
|
|
from .bulkdec import BulkDecodeLayer
|
|
|
|
from .input import OneHotInputLayer
|
|
|
|
import torch
|
|
|
|
import types
|
|
|
|
|
|
|
|
|
|
|
|
class FastModel(torch.nn.Module):
|
|
|
|
def __init(self, prep_d: PreparedData,
|
|
|
|
layer_dimensions: List[int] = [32, 64],
|
|
|
|
keep_prob: float = 1.,
|
|
|
|
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
|
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
|
|
|
|
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
|
**kwargs) -> None:
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
self._check_params(prep_d, layer_dimensions, rel_activation,
|
|
|
|
layer_activation, dec_activation)
|
|
|
|
|
|
|
|
self.prep_d = prep_d
|
|
|
|
self.keep_prob = float(keep_prob)
|
|
|
|
self.rel_activation = rel_activation
|
|
|
|
self.layer_activation = layer_activation
|
|
|
|
self.dec_activation = dec_activation
|
|
|
|
|
|
|
|
self.seq = None
|
|
|
|
self.build()
|
|
|
|
|
|
|
|
def build(self):
|
|
|
|
in_layer = OneHotInputLayer(self.prep_d)
|
|
|
|
last_output_dim = in_layer.output_dim
|
|
|
|
seq = [ in_layer ]
|
|
|
|
|
|
|
|
for dim in self.layer_dimensions:
|
|
|
|
conv_layer = FastConvLayer(input_dim = last_output_dim,
|
|
|
|
output_dim = [dim] * len(self.prep_d.node_types),
|
|
|
|
data = self.prep_d,
|
|
|
|
keep_prob = self.keep_prob,
|
|
|
|
rel_activation = self.rel_activation,
|
|
|
|
layer_activation = self.layer_activation)
|
|
|
|
last_output_dim = conv_layer.output_dim
|
|
|
|
seq.append(conv_layer)
|
|
|
|
|
|
|
|
dec_layer = BulkDecodeLayer(input_dim = last_output_dim,
|
|
|
|
data = self.prep_d,
|
|
|
|
keep_prob = self.keep_prob,
|
|
|
|
activation = self.dec_activation)
|
|
|
|
seq.append(dec_layer)
|
|
|
|
|
|
|
|
seq = torch.nn.Sequential(*seq)
|
|
|
|
self.seq = seq
|
|
|
|
|
|
|
|
def forward(self, _):
|
|
|
|
return self.seq(None)
|
|
|
|
|
|
|
|
def self._check_params(self, prep_d, layer_dimensions, rel_activation,
|
|
|
|
layer_activation, dec_activation):
|
|
|
|
|
|
|
|
if not isinstance(prep_d, PreparedData):
|
|
|
|
raise TypeError('prep_d must be an instanced of PreparedData')
|
|
|
|
|
|
|
|
if not isinstance(layer_dimensions, list):
|
|
|
|
raise TypeError('layer_dimensions must be a list')
|
|
|
|
|
|
|
|
if not isinstance(rel_activation, types.FunctionType):
|
|
|
|
raise TypeError('rel_activation must be a function')
|
|
|
|
|
|
|
|
if not isinstance(layer_activation, types.FunctionType):
|
|
|
|
raise TypeError('layer_activation must be a function')
|
|
|
|
|
|
|
|
if not isinstance(dec_activation, types.FunctionType):
|
|
|
|
raise TypeError('dec_activation must be a function')
|