|
@@ -10,16 +10,16 @@ from .declayer import DecodeLayer |
|
|
from .batch import PredictionsBatch
|
|
|
from .batch import PredictionsBatch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Model(object):
|
|
|
|
|
|
|
|
|
class Model(torch.nn.Module):
|
|
|
def __init__(self, prep_d: PreparedData,
|
|
|
def __init__(self, prep_d: PreparedData,
|
|
|
layer_dimensions: List[int] = [32, 64],
|
|
|
layer_dimensions: List[int] = [32, 64],
|
|
|
keep_prob: float = 1.,
|
|
|
keep_prob: float = 1.,
|
|
|
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
|
|
|
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
|
|
|
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
lr: float = 0.001,
|
|
|
|
|
|
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
|
|
|
|
|
|
batch_size: int = 100) -> None:
|
|
|
|
|
|
|
|
|
**kwargs) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
if not isinstance(prep_d, PreparedData):
|
|
|
if not isinstance(prep_d, PreparedData):
|
|
|
raise TypeError('prep_d must be an instance of PreparedData')
|
|
|
raise TypeError('prep_d must be an instance of PreparedData')
|
|
@@ -38,25 +38,14 @@ class Model(object): |
|
|
if not isinstance(dec_activation, FunctionType):
|
|
|
if not isinstance(dec_activation, FunctionType):
|
|
|
raise TypeError('dec_activation must be a function')
|
|
|
raise TypeError('dec_activation must be a function')
|
|
|
|
|
|
|
|
|
lr = float(lr)
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(loss, FunctionType):
|
|
|
|
|
|
raise TypeError('loss must be a function')
|
|
|
|
|
|
|
|
|
|
|
|
batch_size = int(batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
self.prep_d = prep_d
|
|
|
self.prep_d = prep_d
|
|
|
self.layer_dimensions = layer_dimensions
|
|
|
self.layer_dimensions = layer_dimensions
|
|
|
self.keep_prob = keep_prob
|
|
|
self.keep_prob = keep_prob
|
|
|
self.rel_activation = rel_activation
|
|
|
self.rel_activation = rel_activation
|
|
|
self.layer_activation = layer_activation
|
|
|
self.layer_activation = layer_activation
|
|
|
self.dec_activation = dec_activation
|
|
|
self.dec_activation = dec_activation
|
|
|
self.lr = lr
|
|
|
|
|
|
self.loss = loss
|
|
|
|
|
|
self.batch_size = batch_size
|
|
|
|
|
|
|
|
|
|
|
|
self.seq = None
|
|
|
self.seq = None
|
|
|
self.opt = None
|
|
|
|
|
|
|
|
|
|
|
|
self.build()
|
|
|
self.build()
|
|
|
|
|
|
|
|
@@ -84,40 +73,5 @@ class Model(object): |
|
|
seq = torch.nn.Sequential(*seq)
|
|
|
seq = torch.nn.Sequential(*seq)
|
|
|
self.seq = seq
|
|
|
self.seq = seq
|
|
|
|
|
|
|
|
|
opt = torch.optim.Adam(seq.parameters(), lr=self.lr)
|
|
|
|
|
|
self.opt = opt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_epoch(self):
    """Run one training epoch over all prediction batches.

    The model's predictions are recomputed before every optimizer step
    (stepping mutates the parameters, so the previous predictions are
    stale), the batches are drawn in a shuffled order that is stable for
    the whole epoch, and the scalar loss is accumulated across steps.

    Returns:
        float: sum of the per-batch loss values over the epoch.
    """
    # One unshuffled pass just to count how many batches an epoch has.
    pred = self.seq(None)
    batch = PredictionsBatch(pred, batch_size=self.batch_size)
    n = len(list(iter(batch)))

    # BUG FIX: the original drew ``seed = torch.rand(1).item()`` — a float
    # in [0, 1) — inside the loop.  ``torch.manual_seed`` casts its argument
    # with ``int()``, truncating that float to 0, so every shuffle silently
    # used seed 0.  Draw a proper integer seed once per epoch instead: the
    # shuffle is then identical for every step of this epoch (so skipping
    # ``i`` batches below really resumes the same pass) but differs between
    # epochs.
    seed = torch.randint(0, 0x7FFFFFFF, (1,)).item()

    loss_sum = 0
    for i in range(n):
        self.opt.zero_grad()

        # Recompute predictions with the updated parameters and rebuild
        # the (shuffled) batch iterator over them.
        pred = self.seq(None)
        batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)

        # Seed the global RNG so the shuffle performed inside iter(batch)
        # is reproducible, then restore the previous RNG state so no other
        # randomness in the program is perturbed.
        rng_state = torch.get_rng_state()
        torch.manual_seed(seed)
        it = iter(batch)
        torch.set_rng_state(rng_state)

        # Skip the batches already consumed this epoch; train on the next.
        for _ in range(i):
            _ = next(it)
        (batch_pred, batch_target) = next(it)

        loss = self.loss(batch_pred, batch_target)
        loss.backward()
        self.opt.step()

        loss_sum += loss.detach().cpu().item()
    return loss_sum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(self, max_epochs):
    """Train for ``max_epochs`` epochs, tracking the best epoch seen.

    Args:
        max_epochs (int): number of epochs to run.

    Returns:
        tuple: ``(loss, best_loss, best_epoch)`` — the last epoch's loss,
        the smallest loss observed, and the 0-based index of the epoch
        that produced it.  All three are ``None`` when ``max_epochs``
        is 0.
    """
    # BUG FIX: initialise ``loss`` so a zero-epoch call returns
    # (None, None, None) instead of raising NameError at the return.
    loss = None
    best_loss = None
    best_epoch = None
    for i in range(max_epochs):
        loss = self.run_epoch()
        # Strict '<' keeps the EARLIEST epoch on ties, matching the
        # original behaviour.
        if best_loss is None or loss < best_loss:
            best_loss = loss
            best_epoch = i
    return loss, best_loss, best_epoch
|
|
|
|
|
|
|
|
|
def forward(self, _):
    """Return the model's predictions, ignoring the input placeholder.

    The wrapped sequential network takes no real input (it is fed
    ``None``); the single positional argument exists only to satisfy the
    caller's interface and is discarded.
    """
    out = self.seq(None)
    return out
|