@@ -0,0 +1,127 @@
from .data import Data
from typing import List, Callable
from .trainprep import prepare_training, \
    TrainValTest
import torch
from .convlayer import DecagonLayer
from .input import OneHotInputLayer
from types import FunctionType
from .declayer import DecodeLayer
from .batch import PredictionsBatch


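# Model assembles a one-hot input layer, a stack of DecagonLayer graph
# convolutions (one per entry of layer_dimensions) and a final DecodeLayer
# into a torch.nn.Sequential module, then trains it with Adam on batches
# drawn from PredictionsBatch.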
class Model(object):
    def __init__(self, data: Data,
        layer_dimensions: List[int] = [32, 64],
        ratios: TrainValTest = TrainValTest(.8, .1, .1),
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100) -> None:

        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        if not isinstance(ratios, TrainValTest):
            raise TypeError('ratios must be an instance of TrainValTest')

        keep_prob = float(keep_prob)

        if not isinstance(rel_activation, FunctionType):
            raise TypeError('rel_activation must be a function')

        if not isinstance(layer_activation, FunctionType):
            raise TypeError('layer_activation must be a function')

        if not isinstance(dec_activation, FunctionType):
            raise TypeError('dec_activation must be a function')

        lr = float(lr)

        if not isinstance(loss, FunctionType):
            raise TypeError('loss must be a function')

        batch_size = int(batch_size)

        self.data = data
        self.layer_dimensions = layer_dimensions
        self.ratios = ratios
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation
        self.lr = lr
        self.loss = loss
        self.batch_size = batch_size

        self.build()

    def build(self):
        self.prep_d = prepare_training(self.data, self.ratios)

        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [ in_layer ]

        # One graph convolution layer per requested dimension, each
        # producing `dim` features for every node type.
        for dim in self.layer_dimensions:
            conv_layer = DecagonLayer(input_dim = last_output_dim,
                output_dim = [ dim ] * len(self.prep_d.node_types),
                data = self.prep_d,
                keep_prob = self.keep_prob,
                rel_activation = self.rel_activation,
                layer_activation = self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)

        # Final decoder layer produces the predictions consumed by PredictionsBatch.
        dec_layer = DecodeLayer(input_dim = last_output_dim,
            data = self.prep_d,
            keep_prob = self.keep_prob,
            activation = self.dec_activation)
        seq.append(dec_layer)

        seq = torch.nn.Sequential(*seq)
        self.seq = seq

        opt = torch.optim.Adam(seq.parameters(), lr=self.lr)
        self.opt = opt

    def run_epoch(self):
        # One forward pass just to find out how many batches an epoch yields.
        pred = self.seq(None)
        batch = PredictionsBatch(pred, self.batch_size)
        n = len(list(iter(batch)))

        loss_sum = 0
        for i in range(n - 1):
            self.opt.zero_grad()

            # Recompute the predictions and rebuild the batch view for this step.
            pred = self.seq(None)
            batch = PredictionsBatch(pred, self.batch_size)

            # Seed the RNG only around iter(batch), then restore the previous
            # RNG state so the batch shuffling does not disturb the rest of
            # the model's randomness.
            seed = torch.rand(1).item()
            rng_state = torch.get_rng_state()
            torch.manual_seed(seed)
            it = iter(batch)
            torch.set_rng_state(rng_state)

            # Skip the batches handled in earlier iterations and take the i-th one.
            for k in range(i):
                _ = next(it)
            (input, target) = next(it)

            loss = self.loss(input, target)
            loss.backward()
            self.opt.step()

            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch
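
# Minimal usage sketch (assumptions: `my_data` stands for an already
# populated Data instance from this package; the dimensions and epoch
# count below are illustrative, not prescriptive):
#
#     model = Model(my_data, layer_dimensions=[32, 64], lr=0.001)
#     last_loss, best_loss, best_epoch = model.train(max_epochs=100)
#     print('best epoch:', best_epoch, 'best loss:', best_loss)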