diff --git a/src/triacontagon/loop.py b/src/triacontagon/loop.py
index 18565aa..cc20b5a 100644
--- a/src/triacontagon/loop.py
+++ b/src/triacontagon/loop.py
@@ -1,40 +1,109 @@
-from .model import Model
-from .batch import Batcher
+import torch
+from typing import Callable, List, Tuple
+
+from .model import Model, TrainingBatch
+from .batch import Batcher, DualBatcher
+from .sampling import negative_sample_data
+from .data import Data
 
-class TrainLoop(object):
-    def __init__(self, model: Model,
-        pos_batcher: Batcher,
-        neg_batcher: Batcher,
-        max_epochs: int = 50) -> None:
+def _merge_pos_neg_batches(pos_batch, neg_batch):
+    assert len(pos_batch.edges) == len(neg_batch.edges)
+    assert pos_batch.vertex_type_row == neg_batch.vertex_type_row
+    assert pos_batch.vertex_type_column == neg_batch.vertex_type_column
+    assert pos_batch.relation_type_index == neg_batch.relation_type_index
 
-        if not isinstance(model, Model):
-            raise TypeError('model must be an instance of Model')
+    batch = TrainingBatch(pos_batch.vertex_type_row,
+        pos_batch.vertex_type_column,
+        pos_batch.relation_type_index,
+        torch.cat([ pos_batch.edges, neg_batch.edges ]),
+        torch.cat([
+            torch.ones(len(pos_batch.edges)),
+            torch.zeros(len(neg_batch.edges))
+        ]))
+    return batch
 
-        if not isinstance(pos_batcher, Batcher):
-            raise TypeError('pos_batcher must be an instance of Batcher')
 
-        if not isinstance(neg_batcher, Batcher):
-            raise TypeError('neg_batcher must be an instance of Batcher')
+class TrainLoop(object):
+    def __init__(self, model: Model,
+        val_data: Data, test_data: Data,
+        initial_repr: List[torch.Tensor],
+        max_epochs: int = 50,
+        batch_size: int = 512,
+        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
+            torch.nn.functional.binary_cross_entropy_with_logits,
+        lr: float = 0.001) -> None:
+
+        assert isinstance(model, Model)
+        assert isinstance(val_data, Data)
+        assert isinstance(test_data, Data)
+        assert callable(loss)
 
         self.model = model
-        self.pos_batcher = pos_batcher
-        self.neg_batcher = neg_batcher
-        self.max_epochs = int(num_epochs)
+        self.test_data = test_data
+        self.initial_repr = list(initial_repr)
+        self.max_epochs = int(max_epochs)
+        self.batch_size = int(batch_size)
+        self.loss = loss
+        self.lr = float(lr)
+
+        self.pos_data = model.data
+        self.neg_data = negative_sample_data(model.data)
+
+        self.pos_val_data = val_data
+        self.neg_val_data = negative_sample_data(val_data)
+
+        self.batcher = DualBatcher(self.pos_data, self.neg_data,
+            batch_size=batch_size)
+        self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data)
+
+        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
 
-    def run_epoch(self) -> None:
-        pos_it = iter(self.pos_batcher)
-        neg_it = iter(self.neg_batcher)
-
-        while True:
-            try:
-                pos_batch = next(pos_it)
-                neg_batch = next(neg_it)
-            except StopIteration:
-                break
-
-            if len(pos_batch.edges) != len(neg_batch.edges):
-                raise ValueError('Positive and negative batch should have same length')
+    def run_epoch(self) -> float:
+        loss_sum = 0.
+        for pos_batch, neg_batch in self.batcher:
+            batch = _merge_pos_neg_batches(pos_batch, neg_batch)
+
+            self.opt.zero_grad()
+            last_layer_repr = self.model.convolve(self.initial_repr)
+            pred = self.model.decode(last_layer_repr, batch)
+            loss = self.loss(pred, batch.target_values)
+            loss.backward()
+            self.opt.step()
+
+            loss = loss.detach().cpu().item()
+            loss_sum += loss
+            print('loss:', loss)
+        return loss_sum
+
+    def validate_epoch(self):
+        loss_sum = 0.
+        for pos_batch, neg_batch in self.val_batcher:
+            batch = _merge_pos_neg_batches(pos_batch, neg_batch)
+            with torch.no_grad():
+                last_layer_repr = self.model.convolve(self.initial_repr, eval_mode=True)
+                pred = self.model.decode(last_layer_repr, batch, eval_mode=True)
+                loss = self.loss(pred, batch.target_values)
+            loss = loss.detach().cpu().item()
+            loss_sum += loss
+        return loss_sum
 
-    def run(self) -> None:
+    def run(self) -> Tuple[int, float, float]:
+        best_loss = float('inf')
+        epochs_without_improvement = 0
         for epoch in range(self.max_epochs):
-            self.run_epoch()
+            print('Epoch', epoch)
+            loss_sum = self.run_epoch()
+            print('train loss_sum:', loss_sum)
+            loss_sum = self.validate_epoch()
+            print('val loss_sum:', loss_sum)
+            if loss_sum >= best_loss:
+                epochs_without_improvement += 1
+            else:
+                epochs_without_improvement = 0
+                best_loss = loss_sum
+            if epochs_without_improvement == 2:
+                print('Early stopping after epoch', epoch, 'due to no improvement')
+                break
+
+        return (epoch, best_loss, loss_sum)
diff --git a/src/triacontagon/model.py b/src/triacontagon/model.py
index 1a41b4d..09e6de7 100644
--- a/src/triacontagon/model.py
+++ b/src/triacontagon/model.py
@@ -130,7 +130,7 @@ class Model(torch.nn.Module):
                 torch.nn.Parameter(local_variation[i])
 
-    def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
+    def convolve(self, in_layer_repr: List[torch.Tensor], eval_mode=False) -> \
         List[torch.Tensor]:
 
         cur_layer_repr = in_layer_repr
@@ -145,7 +145,7 @@ class Model(torch.nn.Module):
             num_relation_types = len(et.adjacency_matrices)
             x = cur_layer_repr[vt_col]
 
-            if self.keep_prob != 1:
+            if self.keep_prob != 1 and not eval_mode:
                 x = dropout(x, self.keep_prob)
 
             # print('a, Layer:', i, 'x.shape:', x.shape)
@@ -176,7 +176,7 @@ class Model(torch.nn.Module):
         return cur_layer_repr
 
     def decode(self, last_layer_repr: List[torch.Tensor],
-        batch: TrainingBatch) -> torch.Tensor:
+        batch: TrainingBatch, eval_mode=False) -> torch.Tensor:
 
         vt_row = batch.vertex_type_row
         vt_col = batch.vertex_type_column
@@ -195,8 +195,9 @@ class Model(torch.nn.Module):
         in_row = in_row[batch.edges[:, 0]]
         in_col = in_col[batch.edges[:, 1]]
 
-        in_row = dropout(in_row, self.keep_prob)
-        in_col = dropout(in_col, self.keep_prob)
+        if self.keep_prob != 1 and not eval_mode:
+            in_row = dropout(in_row, self.keep_prob)
+            in_col = dropout(in_col, self.keep_prob)
 
         # in_row = in_row.to_dense()
         # in_col = in_col.to_dense()