|  |  | @@ -0,0 +1,166 @@ | 
		
	
		
			
			|  |  |  | from .fastmodel import FastModel | 
		
	
		
			
			|  |  |  | from .trainprep import PreparedData | 
		
	
		
			
			|  |  |  | import torch | 
		
	
		
			
			|  |  |  | from typing import Callable | 
		
	
		
			
			|  |  |  | from types import FunctionType | 
		
	
		
			
			|  |  |  | import time | 
		
	
		
			
			|  |  |  | import random | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class FastBatcher(object): | 
		
	
		
			
			|  |  |  | def __init__(self, prep_d: PreparedData, batch_size: int, | 
		
	
		
			
			|  |  |  | shuffle: bool, generator: torch.Generator, | 
		
	
		
			
			|  |  |  | part_type: str) -> None: | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not isinstance(prep_d, PreparedData): | 
		
	
		
			
			|  |  |  | raise TypeError('prep_d must be an instance of PreparedData') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not isinstance(generator, torch.Generator): | 
		
	
		
			
			|  |  |  | raise TypeError('generator must be an instance of torch.Generator') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if part_type not in ['train', 'val', 'test']: | 
		
	
		
			
			|  |  |  | raise ValueError('part_type must be set to train, val or test') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.prep_d = prep_d | 
		
	
		
			
			|  |  |  | self.batch_size = int(batch_size) | 
		
	
		
			
			|  |  |  | self.shuffle = bool(shuffle) | 
		
	
		
			
			|  |  |  | self.generator = generator | 
		
	
		
			
			|  |  |  | self.part_type = part_type | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.edges = None | 
		
	
		
			
			|  |  |  | self.targets = None | 
		
	
		
			
			|  |  |  | self.build() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def build(self): | 
		
	
		
			
			|  |  |  | self.edges = [] | 
		
	
		
			
			|  |  |  | self.targets = [] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for fam in self.prep_d.relation_families: | 
		
	
		
			
			|  |  |  | edges = [] | 
		
	
		
			
			|  |  |  | targets = [] | 
		
	
		
			
			|  |  |  | for i, rel in enumerate(fam.relation_types): | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | edges_pos = getattr(rel.edges_pos, self.part_type) | 
		
	
		
			
			|  |  |  | edges_neg = getattr(rel.edges_neg, self.part_type) | 
		
	
		
			
			|  |  |  | edges_back_pos = getattr(rel.edges_back_pos, self.part_type) | 
		
	
		
			
			|  |  |  | edges_back_neg = getattr(rel.edges_back_neg, self.part_type) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | e = torch.cat([ edges_pos, | 
		
	
		
			
			|  |  |  | torch.cat([edges_back_pos[:, 1], edges_back_pos[:, 0]], dim=1) ]) | 
		
	
		
			
			|  |  |  | e = torch.cat([torch.ones(len(e), 1, dtype=torch.long) * i , e ], dim=1) | 
		
	
		
			
			|  |  |  | t = torch.ones(len(e)) | 
		
	
		
			
			|  |  |  | edges.append(e) | 
		
	
		
			
			|  |  |  | targets.append(t) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | e = torch.cat([ edges_neg, | 
		
	
		
			
			|  |  |  | torch.cat([edges_back_neg[:, 1], edges_back_neg[:, 0]], dim=1) ]) | 
		
	
		
			
			|  |  |  | e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1) | 
		
	
		
			
			|  |  |  | t = torch.zeros(len(e)) | 
		
	
		
			
			|  |  |  | edges.append(e) | 
		
	
		
			
			|  |  |  | targets.append(t) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | edges = torch.cat(edges) | 
		
	
		
			
			|  |  |  | targets = torch.cat(targets) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.edges.append(edges) | 
		
	
		
			
			|  |  |  | self.targets.append(targets) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | # print(self.edges) | 
		
	
		
			
			|  |  |  | # print(self.targets) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if self.shuffle: | 
		
	
		
			
			|  |  |  | self.shuffle_families() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def shuffle_families(self): | 
		
	
		
			
			|  |  |  | for i in range(len(self.edges)): | 
		
	
		
			
			|  |  |  | edges = self.edges[i] | 
		
	
		
			
			|  |  |  | targets = self.targets[i] | 
		
	
		
			
			|  |  |  | order = torch.randperm(len(edges), generator=self.generator) | 
		
	
		
			
			|  |  |  | self.edges[i] = edges[order] | 
		
	
		
			
			|  |  |  | self.targets[i] = targets[order] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def __iter__(self): | 
		
	
		
			
			|  |  |  | offsets = [ 0 for _ in self.edges ] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | while True: | 
		
	
		
			
			|  |  |  | choice = [ i for i in range(len(offsets)) \ | 
		
	
		
			
			|  |  |  | if offsets[i] < len(self.edges[i]) ] | 
		
	
		
			
			|  |  |  | if len(choice) == 0: | 
		
	
		
			
			|  |  |  | break | 
		
	
		
			
			|  |  |  | fam_idx = torch.randint(len(choice), (1,), generator=self.generator).item() | 
		
	
		
			
			|  |  |  | ofs = offsets[fam_idx] | 
		
	
		
			
			|  |  |  | edges = self.edges[fam_idx][ofs:ofs + self.batch_size] | 
		
	
		
			
			|  |  |  | targets = self.targets[fam_idx][ofs:ofs + self.batch_size] | 
		
	
		
			
			|  |  |  | offsets[fam_idx] += self.batch_size | 
		
	
		
			
			|  |  |  | yield (fam_idx, edges, targets) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class FastLoop(object): | 
		
	
		
			
			|  |  |  | def __init__( | 
		
	
		
			
			|  |  |  | self, | 
		
	
		
			
			|  |  |  | model: FastModel, | 
		
	
		
			
			|  |  |  | lr: float = 0.001, | 
		
	
		
			
			|  |  |  | loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | 
		
	
		
			
			|  |  |  | torch.nn.functional.binary_cross_entropy_with_logits, | 
		
	
		
			
			|  |  |  | batch_size: int = 100, | 
		
	
		
			
			|  |  |  | shuffle: bool = True, | 
		
	
		
			
			|  |  |  | generator: torch.Generator = None) -> None: | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self._check_params(model, loss, generator) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.model = model | 
		
	
		
			
			|  |  |  | self.lr = float(lr) | 
		
	
		
			
			|  |  |  | self.loss = loss | 
		
	
		
			
			|  |  |  | self.batch_size = int(batch_size) | 
		
	
		
			
			|  |  |  | self.shuffle = bool(shuffle) | 
		
	
		
			
			|  |  |  | self.generator = generator or torch.default_generator | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.opt = None | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.build() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def _check_params(self, model, loss, generator): | 
		
	
		
			
			|  |  |  | if not isinstance(model, FastModel): | 
		
	
		
			
			|  |  |  | raise TypeError('model must be an instance of FastModel') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not isinstance(loss, FunctionType): | 
		
	
		
			
			|  |  |  | raise TypeError('loss must be a function') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if generator is not None and not isinstance(generator, torch.Generator): | 
		
	
		
			
			|  |  |  | raise TypeError('generator must be an instance of torch.Generator') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def build(self) -> None: | 
		
	
		
			
			|  |  |  | opt = torch.optim.Adam(self.model.parameters(), lr=self.lr) | 
		
	
		
			
			|  |  |  | self.opt = opt | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def run_epoch(self): | 
		
	
		
			
			|  |  |  | prep_d = self.model.prep_d | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size, | 
		
	
		
			
			|  |  |  | shuffle = self.shuffle, generator=self.generator) | 
		
	
		
			
			|  |  |  | # pred = self.model(None) | 
		
	
		
			
			|  |  |  | # n = len(list(iter(batch))) | 
		
	
		
			
			|  |  |  | loss_sum = 0 | 
		
	
		
			
			|  |  |  | for fam_idx, edges, targets in batcher: | 
		
	
		
			
			|  |  |  | self.opt.zero_grad() | 
		
	
		
			
			|  |  |  | pred = self.model(None) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | # process pred, get input and targets | 
		
	
		
			
			|  |  |  | input = pred[fam_idx][edges[:, 0], edges[:, 1]] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | loss = self.loss(input, targets) | 
		
	
		
			
			|  |  |  | loss.backward() | 
		
	
		
			
			|  |  |  | self.opt.step() | 
		
	
		
			
			|  |  |  | loss_sum += loss.detach().cpu().item() | 
		
	
		
			
			|  |  |  | return loss_sum | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def train(self, max_epochs): | 
		
	
		
			
			|  |  |  | best_loss = None | 
		
	
		
			
			|  |  |  | best_epoch = None | 
		
	
		
			
			|  |  |  | for i in range(max_epochs): | 
		
	
		
			
			|  |  |  | loss = self.run_epoch() | 
		
	
		
			
			|  |  |  | if best_loss is None or loss < best_loss: | 
		
	
		
			
			|  |  |  | best_loss = loss | 
		
	
		
			
			|  |  |  | best_epoch = i | 
		
	
		
			
			|  |  |  | return loss, best_loss, best_epoch |