|
|
|
|
|
from typing import Callable, List, Tuple

import torch

from .batch import Batcher, DualBatcher
from .data import Data
from .model import Model, TrainingBatch
from .sampling import negative_sample_data
|
|
|
|
|
|
|
|
|
|
|
|
class TrainLoop(object):
    """Mini-batch training loop for a link-prediction Model.

    Each epoch iterates aligned (positive, negative) edge batch pairs,
    merges each pair into a single TrainingBatch with target values 1/0,
    and optimizes the model with Adam. Training stops early after two
    consecutive epochs without validation-loss improvement.
    """

    def __init__(self, model: Model,
        val_data: Data, test_data: Data,
        initial_repr: List[torch.Tensor],
        max_epochs: int = 50,
        batch_size: int = 512,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        lr: float = 0.001) -> None:
        """Initialize the training loop.

        model: the Model to train; its ``.data`` supplies the positive
            training edges.
        val_data / test_data: held-out Data splits.
        initial_repr: initial per-vertex-type representations fed to
            ``model.convolve()``.
        max_epochs: upper bound on the number of training epochs.
        batch_size: number of positive (and negative) edges per batch.
        loss: callable mapping (predictions, targets) to a scalar loss.
        lr: Adam learning rate.
        """

        # Validate inputs with explicit exceptions rather than asserts,
        # which are stripped under `python -O`.
        if not isinstance(model, Model):
            raise TypeError('model must be an instance of Model')
        if not isinstance(val_data, Data):
            raise TypeError('val_data must be an instance of Data')
        if not isinstance(test_data, Data):
            raise TypeError('test_data must be an instance of Data')
        if not callable(loss):
            raise TypeError('loss must be callable')

        self.model = model
        self.test_data = test_data
        self.initial_repr = list(initial_repr)
        self.max_epochs = int(max_epochs)  # fix: was undefined name `num_epochs`
        self.batch_size = int(batch_size)
        self.loss = loss
        self.lr = float(lr)

        # Negative examples are drawn once up front by corrupting the
        # positive edge sets of the respective splits.
        self.pos_data = model.data
        self.neg_data = negative_sample_data(model.data)

        self.pos_val_data = val_data
        self.neg_val_data = negative_sample_data(val_data)

        # DualBatcher yields aligned (positive, negative) batch pairs.
        self.batcher = DualBatcher(self.pos_data, self.neg_data,
            batch_size=batch_size)
        self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data)

        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    @staticmethod
    def _merge_pos_neg_batches(pos_batch: TrainingBatch,
        neg_batch: TrainingBatch) -> TrainingBatch:
        """Concatenate a positive and a negative batch into a single
        TrainingBatch whose target values are 1. for the positive edges
        and 0. for the negative ones.

        Both batches must describe the same relation and have the same
        number of edges; raises ValueError / AssertionError otherwise.
        """
        # fix: this helper was nested inside an old __init__ and therefore
        # not reachable from run_epoch()/validate_epoch(); it is now a
        # static method of the class.
        if len(pos_batch.edges) != len(neg_batch.edges):
            raise ValueError('Positive and negative batch should have same length')
        assert pos_batch.vertex_type_row == neg_batch.vertex_type_row
        assert pos_batch.vertex_type_column == neg_batch.vertex_type_column
        assert pos_batch.relation_type_index == neg_batch.relation_type_index

        return TrainingBatch(pos_batch.vertex_type_row,
            pos_batch.vertex_type_column,
            pos_batch.relation_type_index,
            torch.cat([ pos_batch.edges, neg_batch.edges ]),
            torch.cat([
                torch.ones(len(pos_batch.edges)),
                torch.zeros(len(neg_batch.edges))
            ]))

    def run_epoch(self) -> float:
        """Run one training epoch over all batch pairs; return the sum of
        the per-batch losses."""
        # fix: removed dead pre-loop that manually iterated
        # self.pos_batcher/self.neg_batcher — attributes that were never
        # set; DualBatcher already yields aligned pairs.
        loss_sum = 0.
        for pos_batch, neg_batch in self.batcher:
            batch = self._merge_pos_neg_batches(pos_batch, neg_batch)

            self.opt.zero_grad()
            last_layer_repr = self.model.convolve(self.initial_repr)
            pred = self.model.decode(last_layer_repr, batch)
            loss = self.loss(pred, batch.target_values)
            loss.backward()
            self.opt.step()

            loss = loss.detach().cpu().item()
            loss_sum += loss
            print('loss:', loss)
        return loss_sum

    def validate_epoch(self) -> float:
        """Compute the summed validation loss without updating parameters."""
        loss_sum = 0.
        for pos_batch, neg_batch in self.val_batcher:
            batch = self._merge_pos_neg_batches(pos_batch, neg_batch)
            # no_grad: validation must not accumulate gradients.
            with torch.no_grad():
                last_layer_repr = self.model.convolve(self.initial_repr,
                    eval_mode=True)
                pred = self.model.decode(last_layer_repr, batch,
                    eval_mode=True)
                loss = self.loss(pred, batch.target_values)
            loss = loss.detach().cpu().item()
            loss_sum += loss
        return loss_sum

    def run(self) -> Tuple[int, float, float]:
        """Train for up to ``max_epochs`` epochs with early stopping after
        two consecutive epochs without validation improvement.

        Returns (last epoch index, best validation loss, last validation
        loss). Annotation fixed: the old signature said ``-> None`` while
        returning a tuple.
        """
        best_loss = float('inf')
        epochs_without_improvement = 0
        # Defined up front so the return statement is valid even when
        # max_epochs == 0.
        epoch, loss_sum = -1, float('inf')
        for epoch in range(self.max_epochs):
            print('Epoch', epoch)
            # fix: run_epoch() used to be invoked twice per epoch,
            # silently doubling the number of optimizer steps.
            loss_sum = self.run_epoch()
            print('train loss_sum:', loss_sum)
            loss_sum = self.validate_epoch()
            print('val loss_sum:', loss_sum)
            if loss_sum >= best_loss:
                epochs_without_improvement += 1
            else:
                epochs_without_improvement = 0
                best_loss = loss_sum
            if epochs_without_improvement == 2:
                print('Early stopping after epoch', epoch, 'due to no improvement')
                break  # fix: early stopping printed but never stopped
        return (epoch, best_loss, loss_sum)
|