from typing import Callable, List, Optional, Tuple

import torch

# NOTE(review): the original imported only `Batcher` but instantiated
# `DualBatcher`; assuming `DualBatcher` lives in `.batch` — confirm.
from .batch import Batcher, DualBatcher
from .data import Data
from .model import Model, TrainingBatch
from .sampling import negative_sample_data


def _merge_pos_neg_batches(pos_batch: TrainingBatch,
    neg_batch: TrainingBatch) -> TrainingBatch:
    """Concatenate a positive and a negative batch into one TrainingBatch.

    The two batches must describe the same relation (same row/column vertex
    types and relation index) and contain the same number of edges.  The
    merged batch carries target values of 1.0 for the positive edges and
    0.0 for the negative ones, suitable as labels for a binary
    cross-entropy loss.
    """
    assert len(pos_batch.edges) == len(neg_batch.edges)
    assert pos_batch.vertex_type_row == neg_batch.vertex_type_row
    assert pos_batch.vertex_type_column == neg_batch.vertex_type_column
    assert pos_batch.relation_type_index == neg_batch.relation_type_index

    batch = TrainingBatch(pos_batch.vertex_type_row,
        pos_batch.vertex_type_column,
        pos_batch.relation_type_index,
        torch.cat([ pos_batch.edges, neg_batch.edges ]),
        torch.cat([
            torch.ones(len(pos_batch.edges)),
            torch.zeros(len(neg_batch.edges))
        ]))
    return batch


class TrainLoop(object):
    """Mini-batch training loop with validation-based early stopping.

    Trains `model` on its own data (plus negatively-sampled counterparts)
    and evaluates on `val_data` after every epoch.  Training stops after
    `max_epochs` epochs, or earlier if the validation loss fails to
    improve for two consecutive epochs.
    """

    def __init__(self, model: Model, val_data: Data, test_data: Data,
        initial_repr: List[torch.Tensor], max_epochs: int = 50,
        batch_size: int = 512,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        lr: float = 0.001) -> None:
        """Initialize the training loop.

        Parameters:
            model: the Model to train; its `.data` provides the training set.
            val_data: held-out data used for per-epoch validation.
            test_data: held-out data reserved for final testing (stored only).
            initial_repr: initial per-vertex-type representations fed to
                `model.convolve()`.
            max_epochs: maximum number of training epochs.
            batch_size: number of edges per training batch.
            loss: binary classification loss taking (logits, targets).
            lr: Adam learning rate.
        """
        assert isinstance(model, Model)
        assert isinstance(val_data, Data)
        assert isinstance(test_data, Data)
        assert callable(loss)

        self.model = model
        self.test_data = test_data
        self.initial_repr = list(initial_repr)
        # BUG FIX: original read the undefined name `num_epochs`;
        # the constructor parameter is `max_epochs`.
        self.max_epochs = int(max_epochs)
        self.batch_size = int(batch_size)
        self.loss = loss
        self.lr = float(lr)

        self.pos_data = model.data
        self.neg_data = negative_sample_data(model.data)

        self.pos_val_data = val_data
        self.neg_val_data = negative_sample_data(val_data)

        self.batcher = DualBatcher(self.pos_data, self.neg_data,
            batch_size=batch_size)
        self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data)

        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def run_epoch(self) -> float:
        """Run one training epoch; return the summed training loss."""
        loss_sum = 0.
        for pos_batch, neg_batch in self.batcher:
            batch = _merge_pos_neg_batches(pos_batch, neg_batch)

            self.opt.zero_grad()
            last_layer_repr = self.model.convolve(self.initial_repr)
            pred = self.model.decode(last_layer_repr, batch)
            loss = self.loss(pred, batch.target_values)
            loss.backward()
            self.opt.step()

            loss = loss.detach().cpu().item()
            loss_sum += loss
            print('loss:', loss)
        return loss_sum

    def validate_epoch(self) -> float:
        """Evaluate on the validation set; return the summed validation loss.

        Runs under `torch.no_grad()` and with `eval_mode=True` so no
        gradients are accumulated and eval-time behavior (e.g. no dropout)
        is used.
        """
        loss_sum = 0.
        for pos_batch, neg_batch in self.val_batcher:
            batch = _merge_pos_neg_batches(pos_batch, neg_batch)
            with torch.no_grad():
                last_layer_repr = self.model.convolve(self.initial_repr,
                    eval_mode=True)
                pred = self.model.decode(last_layer_repr, batch,
                    eval_mode=True)
                loss = self.loss(pred, batch.target_values)
            loss = loss.detach().cpu().item()
            loss_sum += loss
        return loss_sum

    def run(self) -> Optional[Tuple[int, float, float]]:
        """Train for up to `max_epochs` epochs with early stopping.

        Stops early after two consecutive epochs without validation-loss
        improvement.  Returns `(last_epoch, best_val_loss, last_val_loss)`,
        or None if `max_epochs` is 0 and no epoch ran.  (The original only
        returned the tuple on early stop and None after a full run; now the
        summary is returned in both cases.)
        """
        best_loss = float('inf')
        epochs_without_improvement = 0
        epoch = -1
        loss_sum = float('inf')
        for epoch in range(self.max_epochs):
            print('Epoch', epoch)
            loss_sum = self.run_epoch()
            print('train loss_sum:', loss_sum)
            loss_sum = self.validate_epoch()
            print('val loss_sum:', loss_sum)
            if loss_sum >= best_loss:
                epochs_without_improvement += 1
            else:
                epochs_without_improvement = 0
                best_loss = loss_sum
            if epochs_without_improvement == 2:
                print('Early stopping after epoch', epoch,
                    'due to no improvement')
                return (epoch, best_loss, loss_sum)
        # No early stop: still report the training summary.
        return (epoch, best_loss, loss_sum) if epoch >= 0 else None