IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Add FastModel.

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
f99f0bb919
共有 1 个文件被更改,包括 74 次插入0 次删除
  1. +74
    -0
      src/icosagon/fastmodel.py

+ 74
- 0
src/icosagon/fastmodel.py 查看文件

@@ -0,0 +1,74 @@
from .fastconv import FastConvLayer
from .bulkdec import BulkDecodeLayer
from .input import OneHotInputLayer
import torch
import types
class FastModel(torch.nn.Module):
def __init(self, prep_d: PreparedData,
layer_dimensions: List[int] = [32, 64],
keep_prob: float = 1.,
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
**kwargs) -> None:
super().__init__(**kwargs)
self._check_params(prep_d, layer_dimensions, rel_activation,
layer_activation, dec_activation)
self.prep_d = prep_d
self.keep_prob = float(keep_prob)
self.rel_activation = rel_activation
self.layer_activation = layer_activation
self.dec_activation = dec_activation
self.seq = None
self.build()
def build(self):
in_layer = OneHotInputLayer(self.prep_d)
last_output_dim = in_layer.output_dim
seq = [ in_layer ]
for dim in self.layer_dimensions:
conv_layer = FastConvLayer(input_dim = last_output_dim,
output_dim = [dim] * len(self.prep_d.node_types),
data = self.prep_d,
keep_prob = self.keep_prob,
rel_activation = self.rel_activation,
layer_activation = self.layer_activation)
last_output_dim = conv_layer.output_dim
seq.append(conv_layer)
dec_layer = BulkDecodeLayer(input_dim = last_output_dim,
data = self.prep_d,
keep_prob = self.keep_prob,
activation = self.dec_activation)
seq.append(dec_layer)
seq = torch.nn.Sequential(*seq)
self.seq = seq
def forward(self, _):
return self.seq(None)
def self._check_params(self, prep_d, layer_dimensions, rel_activation,
layer_activation, dec_activation):
if not isinstance(prep_d, PreparedData):
raise TypeError('prep_d must be an instanced of PreparedData')
if not isinstance(layer_dimensions, list):
raise TypeError('layer_dimensions must be a list')
if not isinstance(rel_activation, types.FunctionType):
raise TypeError('rel_activation must be a function')
if not isinstance(layer_activation, types.FunctionType):
raise TypeError('layer_activation must be a function')
if not isinstance(dec_activation, types.FunctionType):
raise TypeError('dec_activation must be a function')

正在加载...
取消
保存