|  |  | @@ -1,40 +1,105 @@ | 
		
	
		
			
			|  |  |  | from .model import Model | 
		
	
		
			
			|  |  |  | from .batch import Batcher | 
		
	
		
			
			|  |  |  | from .sampling import negative_sample_data | 
		
	
		
			
			|  |  |  | from .data import Data | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class TrainLoop(object): | 
		
	
		
			
			|  |  |  | def __init__(self, model: Model, | 
		
	
		
			
			|  |  |  | pos_batcher: Batcher, | 
		
	
		
			
			|  |  |  | neg_batcher: Batcher, | 
		
	
		
			
			|  |  |  | max_epochs: int = 50) -> None: | 
		
	
		
			
			|  |  |  | def _merge_pos_neg_batches(pos_batch, neg_batch): | 
		
	
		
			
			|  |  |  | assert len(pos_batch.edges) == len(neg_batch.edges) | 
		
	
		
			
			|  |  |  | assert pos_batch.vertex_type_row == neg_batch.vertex_type_row | 
		
	
		
			
			|  |  |  | assert pos_batch.vertex_type_column == neg_batch.vertex_type_column | 
		
	
		
			
			|  |  |  | assert pos_batch.relation_type_index == neg_batch.relation_type_index | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not isinstance(model, Model): | 
		
	
		
			
			|  |  |  | raise TypeError('model must be an instance of Model') | 
		
	
		
			
			|  |  |  | batch = TrainingBatch(pos_batch.vertex_type_row, | 
		
	
		
			
			|  |  |  | pos_batch.vertex_type_column, | 
		
	
		
			
			|  |  |  | pos_batch.relation_type_index, | 
		
	
		
			
			|  |  |  | torch.cat([ pos_batch.edges, neg_batch.edges ]), | 
		
	
		
			
			|  |  |  | torch.cat([ | 
		
	
		
			
			|  |  |  | torch.ones(len(pos_batch.edges)), | 
		
	
		
			
			|  |  |  | torch.zeros(len(neg_batch.edges)) | 
		
	
		
			
			|  |  |  | ])) | 
		
	
		
			
			|  |  |  | return batch | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not isinstance(pos_batcher, Batcher): | 
		
	
		
			
			|  |  |  | raise TypeError('pos_batcher must be an instance of Batcher') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not isinstance(neg_batcher, Batcher): | 
		
	
		
			
			|  |  |  | raise TypeError('neg_batcher must be an instance of Batcher') | 
		
	
		
			
			|  |  |  | class TrainLoop(object): | 
		
	
		
			
			|  |  |  | def __init__(self, model: Model, | 
		
	
		
			
			|  |  |  | val_data: Data, test_data: Data, | 
		
	
		
			
			|  |  |  | initial_repr: List[torch.Tensor], | 
		
	
		
			
			|  |  |  | max_epochs: int = 50, | 
		
	
		
			
			|  |  |  | batch_size: int = 512, | 
		
	
		
			
			|  |  |  | loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | 
		
	
		
			
			|  |  |  | torch.nn.functional.binary_cross_entropy_with_logits, | 
		
	
		
			
			|  |  |  | lr: float = 0.001) -> None: | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert isinstance(model, Model) | 
		
	
		
			
			|  |  |  | assert isinstance(val_data, Data) | 
		
	
		
			
			|  |  |  | assert isinstance(test_data, Data) | 
		
	
		
			
			|  |  |  | assert callable(loss) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.model = model | 
		
	
		
			
			|  |  |  | self.pos_batcher = pos_batcher | 
		
	
		
			
			|  |  |  | self.neg_batcher = neg_batcher | 
		
	
		
			
			|  |  |  | self.test_data = test_data | 
		
	
		
			
			|  |  |  | self.initial_repr = list(initial_repr) | 
		
	
		
			
			|  |  |  | self.max_epochs = int(num_epochs) | 
		
	
		
			
			|  |  |  | self.batch_size = int(batch_size) | 
		
	
		
			
			|  |  |  | self.loss = loss | 
		
	
		
			
			|  |  |  | self.lr = float(lr) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.pos_data = model.data | 
		
	
		
			
			|  |  |  | self.neg_data = negative_sample_data(model.data) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.pos_val_data = val_data | 
		
	
		
			
			|  |  |  | self.neg_val_data = negative_sample_data(val_data) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.batcher = DualBatcher(self.pos_data, self.neg_data, | 
		
	
		
			
			|  |  |  | batch_size=batch_size) | 
		
	
		
			
			|  |  |  | self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def run_epoch(self) -> None: | 
		
	
		
			
			|  |  |  | pos_it = iter(self.pos_batcher) | 
		
	
		
			
			|  |  |  | neg_it = iter(self.neg_batcher) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | while True: | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | pos_batch = next(pos_it) | 
		
	
		
			
			|  |  |  | neg_batch = next(neg_it) | 
		
	
		
			
			|  |  |  | except StopIteration: | 
		
	
		
			
			|  |  |  | break | 
		
	
		
			
			|  |  |  | if len(pos_batch.edges) != len(neg_batch.edges): | 
		
	
		
			
			|  |  |  | raise ValueError('Positive and negative batch should have same length') | 
		
	
		
			
			|  |  |  | loss_sum = 0. | 
		
	
		
			
			|  |  |  | for pos_batch, neg_batch in self.batcher: | 
		
	
		
			
			|  |  |  | batch = _merge_pos_neg_batches(pos_batch, neg_batch) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.opt.zero_grad() | 
		
	
		
			
			|  |  |  | last_layer_repr = self.model.convolve(self.initial_repr) | 
		
	
		
			
			|  |  |  | pred = self.model.decode(last_layer_repr, batch) | 
		
	
		
			
			|  |  |  | loss = self.loss(pred, batch.target_values) | 
		
	
		
			
			|  |  |  | loss.backward() | 
		
	
		
			
			|  |  |  | self.opt.step() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | loss = loss.detach().cpu().item() | 
		
	
		
			
			|  |  |  | loss_sum += loss | 
		
	
		
			
			|  |  |  | print('loss:', loss) | 
		
	
		
			
			|  |  |  | return loss_sum | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def validate_epoch(self): | 
		
	
		
			
			|  |  |  | loss_sum = 0. | 
		
	
		
			
			|  |  |  | for pos_batch, neg_batch in self.val_batcher: | 
		
	
		
			
			|  |  |  | batch = _merge_pos_neg_batches(pos_batch, neg_batch) | 
		
	
		
			
			|  |  |  | with torch.no_grad(): | 
		
	
		
			
			|  |  |  | last_layer_repr = self.model.convolve(self.initial_repr, eval_mode=True) | 
		
	
		
			
			|  |  |  | pred = self.model.decode(last_layer_repr, batch, eval_mode=True) | 
		
	
		
			
			|  |  |  | loss = self.loss(pred, batch.target_values) | 
		
	
		
			
			|  |  |  | loss = loss.detach().cpu().item() | 
		
	
		
			
			|  |  |  | loss_sum += loss | 
		
	
		
			
			|  |  |  | return loss_sum | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def run(self) -> None: | 
		
	
		
			
			|  |  |  | best_loss = float('inf') | 
		
	
		
			
			|  |  |  | epochs_without_improvement = 0 | 
		
	
		
			
			|  |  |  | for epoch in range(self.max_epochs): | 
		
	
		
			
			|  |  |  | self.run_epoch() | 
		
	
		
			
			|  |  |  | print('Epoch', epoch) | 
		
	
		
			
			|  |  |  | loss_sum = self.run_epoch() | 
		
	
		
			
			|  |  |  | print('train loss_sum:', loss_sum) | 
		
	
		
			
			|  |  |  | loss_sum = self.validate_epoch() | 
		
	
		
			
			|  |  |  | print('val loss_sum:', loss_sum) | 
		
	
		
			
			|  |  |  | if loss_sum >= best_loss: | 
		
	
		
			
			|  |  |  | epochs_without_improvement += 1 | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | epochs_without_improvement = 0 | 
		
	
		
			
			|  |  |  | best_loss = loss_sum | 
		
	
		
			
			|  |  |  | if epochs_without_improvement == 2: | 
		
	
		
			
			|  |  |  | print('Early stopping after epoch', epoch, 'due to no improvement') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | return (epoch, best_loss, loss_sum) |