from typing import Callable, List

import torch

# NOTE(review): DualBatcher and TrainingBatch are referenced below but were not
# imported in the original; assumed to live in .batch alongside Batcher — confirm.
from .batch import Batcher, DualBatcher, TrainingBatch
from .data import Data
from .model import Model
from .sampling import negative_sample_data
		
		
	
		
			
			 | 
			 | 
			
			 | 
			class TrainLoop(object):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    def __init__(self, model: Model,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        pos_batcher: Batcher,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        neg_batcher: Batcher,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        max_epochs: int = 50) -> None:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def _merge_pos_neg_batches(pos_batch, neg_batch):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert len(pos_batch.edges) == len(neg_batch.edges)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert pos_batch.vertex_type_row == neg_batch.vertex_type_row
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert pos_batch.vertex_type_column == neg_batch.vertex_type_column
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    assert pos_batch.relation_type_index == neg_batch.relation_type_index
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        if not isinstance(model, Model):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            raise TypeError('model must be an instance of Model')
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    batch = TrainingBatch(pos_batch.vertex_type_row,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        pos_batch.vertex_type_column,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        pos_batch.relation_type_index,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        torch.cat([ pos_batch.edges, neg_batch.edges ]),
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        torch.cat([
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            torch.ones(len(pos_batch.edges)),
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            torch.zeros(len(neg_batch.edges))
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        ]))
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    return batch
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        if not isinstance(pos_batcher, Batcher):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            raise TypeError('pos_batcher must be an instance of Batcher')
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        if not isinstance(neg_batcher, Batcher):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            raise TypeError('neg_batcher must be an instance of Batcher')
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			class TrainLoop(object):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    def __init__(self, model: Model,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        val_data: Data, test_data: Data,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        initial_repr: List[torch.Tensor],
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        max_epochs: int = 50,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        batch_size: int = 512,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            torch.nn.functional.binary_cross_entropy_with_logits,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        lr: float = 0.001) -> None:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        assert isinstance(model, Model)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        assert isinstance(val_data, Data)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        assert isinstance(test_data, Data)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        assert callable(loss)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.model = model
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.pos_batcher = pos_batcher
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.neg_batcher = neg_batcher
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.test_data = test_data
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.initial_repr = list(initial_repr)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.max_epochs = int(num_epochs)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.batch_size = int(batch_size)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.loss = loss
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.lr = float(lr)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.pos_data = model.data
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.neg_data = negative_sample_data(model.data)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.pos_val_data = val_data
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.neg_val_data = negative_sample_data(val_data)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.batcher = DualBatcher(self.pos_data, self.neg_data,
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            batch_size=batch_size)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    def run_epoch(self) -> None:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        pos_it = iter(self.pos_batcher)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        neg_it = iter(self.neg_batcher)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        while True:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            try:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                pos_batch = next(pos_it)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                neg_batch = next(neg_it)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            except StopIteration:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                break
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            if len(pos_batch.edges) != len(neg_batch.edges):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                raise ValueError('Positive and negative batch should have same length')
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        loss_sum = 0.
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        for pos_batch, neg_batch in self.batcher:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            batch = _merge_pos_neg_batches(pos_batch, neg_batch)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            self.opt.zero_grad()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            last_layer_repr = self.model.convolve(self.initial_repr)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            pred = self.model.decode(last_layer_repr, batch)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            loss = self.loss(pred, batch.target_values)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            loss.backward()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            self.opt.step()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            loss = loss.detach().cpu().item()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            loss_sum += loss
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            print('loss:', loss)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        return loss_sum
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    def validate_epoch(self):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        loss_sum = 0.
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        for pos_batch, neg_batch in self.val_batcher:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            batch = _merge_pos_neg_batches(pos_batch, neg_batch)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            with torch.no_grad():
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                last_layer_repr = self.model.convolve(self.initial_repr, eval_mode=True)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                pred = self.model.decode(last_layer_repr, batch, eval_mode=True)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                loss = self.loss(pred, batch.target_values)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                loss = loss.detach().cpu().item()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                loss_sum += loss
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        return loss_sum
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    def run(self) -> None:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        best_loss = float('inf')
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        epochs_without_improvement = 0
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        for epoch in range(self.max_epochs):
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            self.run_epoch()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            print('Epoch', epoch)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            loss_sum = self.run_epoch()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            print('train loss_sum:', loss_sum)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            loss_sum = self.validate_epoch()
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            print('val loss_sum:', loss_sum)
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            if loss_sum >= best_loss:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                epochs_without_improvement += 1
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            else:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                epochs_without_improvement = 0
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                best_loss = loss_sum
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            if epochs_without_improvement == 2:
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                print('Early stopping after epoch', epoch, 'due to no improvement')
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        return (epoch, best_loss, loss_sum)
 |