|
@@ -1,5 +1,6 @@ |
|
|
from .data import Data
|
|
|
from .data import Data
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
from typing import List, \
|
|
|
|
|
|
Callable
|
|
|
from .trainprep import prepare_training, \
|
|
|
from .trainprep import prepare_training, \
|
|
|
TrainValTest
|
|
|
TrainValTest
|
|
|
import torch
|
|
|
import torch
|
|
@@ -19,13 +20,13 @@ class Model(object): |
|
|
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
|
|
|
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
|
|
|
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
lr: float = 0.001,
|
|
|
lr: float = 0.001,
|
|
|
loss = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
|
|
|
|
|
|
|
|
|
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
|
|
|
batch_size: int = 100) -> None:
|
|
|
batch_size: int = 100) -> None:
|
|
|
|
|
|
|
|
|
if not isinstance(data, Data):
|
|
|
if not isinstance(data, Data):
|
|
|
raise TypeError('data must be an instance of Data')
|
|
|
raise TypeError('data must be an instance of Data')
|
|
|
|
|
|
|
|
|
if not isinstance(layer_sizes, list):
|
|
|
|
|
|
|
|
|
if not isinstance(layer_dimensions, list):
|
|
|
raise TypeError('layer_dimensions must be a list')
|
|
|
raise TypeError('layer_dimensions must be a list')
|
|
|
|
|
|
|
|
|
if not isinstance(ratios, TrainValTest):
|
|
|
if not isinstance(ratios, TrainValTest):
|
|
@@ -60,6 +61,10 @@ class Model(object): |
|
|
self.loss = loss
|
|
|
self.loss = loss
|
|
|
self.batch_size = batch_size
|
|
|
self.batch_size = batch_size
|
|
|
|
|
|
|
|
|
|
|
|
self.prep_d = None
|
|
|
|
|
|
self.seq = None
|
|
|
|
|
|
self.opt = None
|
|
|
|
|
|
|
|
|
self.build()
|
|
|
self.build()
|
|
|
|
|
|
|
|
|
def build(self):
|
|
|
def build(self):
|
|
|