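"""Training loop for the link-prediction model.

Pairs every batch of positive edges with a batch of negatively-sampled
edges, trains with Adam (binary cross-entropy on edge logits by default),
and early-stops when the validation loss stops improving.
"""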
from .model import Model, TrainingBatch
from .batch import DualBatcher
from .sampling import negative_sample_data
from .data import Data
import torch
from typing import List, Callable, Tuple


def _merge_pos_neg_batches(pos_batch: TrainingBatch,
        neg_batch: TrainingBatch) -> TrainingBatch:

    # Both batches must have the same size and describe the same relation,
    # otherwise concatenating them would produce a malformed batch.
    assert len(pos_batch.edges) == len(neg_batch.edges)
    assert pos_batch.vertex_type_row == neg_batch.vertex_type_row
    assert pos_batch.vertex_type_column == neg_batch.vertex_type_column
    assert pos_batch.relation_type_index == neg_batch.relation_type_index

    # Concatenate positive and negative edges into one batch and build the
    # matching target vector: 1.0 for positive edges, 0.0 for negative ones.
    batch = TrainingBatch(pos_batch.vertex_type_row,
        pos_batch.vertex_type_column,
        pos_batch.relation_type_index,
        torch.cat([ pos_batch.edges, neg_batch.edges ]),
        torch.cat([
            torch.ones(len(pos_batch.edges)),
            torch.zeros(len(neg_batch.edges))
        ]))
    return batch

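# Illustration (hypothetical values): if pos_batch.edges is
# tensor([[0, 1], [2, 3]]) and neg_batch.edges is tensor([[0, 7], [2, 5]]),
# the merged batch carries edges tensor([[0, 1], [2, 3], [0, 7], [2, 5]])
# and target_values tensor([1., 1., 0., 0.]).

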
class TrainLoop(object):
    def __init__(self, model: Model,
            val_data: Data, test_data: Data,
            initial_repr: List[torch.Tensor],
            max_epochs: int = 50,
            batch_size: int = 512,
            loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] =
                torch.nn.functional.binary_cross_entropy_with_logits,
            lr: float = 0.001) -> None:

        assert isinstance(model, Model)
        assert isinstance(val_data, Data)
        assert isinstance(test_data, Data)
        assert callable(loss)

        self.model = model
        self.test_data = test_data
        self.initial_repr = list(initial_repr)
        self.max_epochs = int(max_epochs)
        self.batch_size = int(batch_size)
        self.loss = loss
        self.lr = float(lr)

        # Training positives come from the model's own graph; negatives
        # are obtained by corrupting them with negative_sample_data().
        self.pos_data = model.data
        self.neg_data = negative_sample_data(model.data)

        self.pos_val_data = val_data
        self.neg_val_data = negative_sample_data(val_data)

        # DualBatcher yields aligned (positive, negative) batch pairs.
        self.batcher = DualBatcher(self.pos_data, self.neg_data,
            batch_size=batch_size)
        self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data)

        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def run_epoch(self) -> float:
        loss_sum = 0.
        for pos_batch, neg_batch in self.batcher:
            batch = _merge_pos_neg_batches(pos_batch, neg_batch)

            # Standard training step: convolve to get the final node
            # representations, decode edge scores, backpropagate, update.
            self.opt.zero_grad()
            last_layer_repr = self.model.convolve(self.initial_repr)
            pred = self.model.decode(last_layer_repr, batch)
            loss = self.loss(pred, batch.target_values)
            loss.backward()
            self.opt.step()

            loss = loss.detach().cpu().item()
            loss_sum += loss
            print('loss:', loss)
        return loss_sum

    def validate_epoch(self) -> float:
        loss_sum = 0.
        for pos_batch, neg_batch in self.val_batcher:
            batch = _merge_pos_neg_batches(pos_batch, neg_batch)
            # No gradients are needed here; eval_mode asks the model to
            # use its evaluation-time behavior.
            with torch.no_grad():
                last_layer_repr = self.model.convolve(self.initial_repr,
                    eval_mode=True)
                pred = self.model.decode(last_layer_repr, batch,
                    eval_mode=True)
                loss = self.loss(pred, batch.target_values)
            loss = loss.detach().cpu().item()
            loss_sum += loss
        return loss_sum

    def run(self) -> Tuple[int, float, float]:
        best_loss = float('inf')
        epochs_without_improvement = 0
        for epoch in range(self.max_epochs):
            print('Epoch', epoch)
            loss_sum = self.run_epoch()
            print('train loss_sum:', loss_sum)
            loss_sum = self.validate_epoch()
            print('val loss_sum:', loss_sum)
            # Early stopping: stop after two consecutive epochs without
            # improvement in validation loss.
            if loss_sum >= best_loss:
                epochs_without_improvement += 1
            else:
                epochs_without_improvement = 0
                best_loss = loss_sum
            if epochs_without_improvement == 2:
                print('Early stopping after epoch', epoch,
                    'due to no improvement')
                break

        return (epoch, best_loss, loss_sum)
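
# Minimal usage sketch (hypothetical setup: how the Model and the
# train/val/test Data splits are constructed is project-specific and
# not shown here):
#
#     loop = TrainLoop(model, val_data, test_data, initial_repr,
#         max_epochs=50, batch_size=512, lr=0.001)
#     last_epoch, best_val_loss, last_val_loss = loop.run()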