from .fastmodel import FastModel
from .trainprep import PreparedData
import torch
from typing import Callable
from types import FunctionType
import time
import random


class FastBatcher(object):
    """Yields (family_index, edges, targets) minibatches drawn from one
    partition ('train'/'val'/'test') of every relation family in a
    PreparedData object.

    Edge rows are [relation_index, node_row, node_col]; targets are 1.0 for
    positive and 0.0 for negative edges.
    """

    def __init__(self, prep_d: PreparedData, batch_size: int, shuffle: bool,
        generator: torch.Generator, part_type: str) -> None:
        """Validate arguments, store them, and eagerly build the batches.

        Raises:
            TypeError: if prep_d / generator have the wrong type.
            ValueError: if part_type is not 'train', 'val' or 'test'.
        """
        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')
        if not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')
        if part_type not in ['train', 'val', 'test']:
            raise ValueError('part_type must be set to train, val or test')
        self.prep_d = prep_d
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator
        self.part_type = part_type
        self.edges = None       # list: one (n, 3) long tensor per family
        self.targets = None     # list: one (n,) float tensor per family
        self.build()

    def build(self):
        """Concatenate the positive and negative (forward and back) edges of
        every relation type into one edge tensor and one target tensor per
        relation family, prefixing each edge row with its relation index."""
        self.edges = []
        self.targets = []
        for fam in self.prep_d.relation_families:
            edges = []
            targets = []
            for i, rel in enumerate(fam.relation_types):
                edges_pos = getattr(rel.edges_pos, self.part_type)
                edges_neg = getattr(rel.edges_neg, self.part_type)
                edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
                edges_back_neg = getattr(rel.edges_back_neg, self.part_type)

                # Back edges are stored with swapped endpoints; reorder their
                # columns so they share the forward edges' orientation.
                # (BUGFIX: the original torch.cat([col1, col0], dim=1) raised
                # IndexError — 1-D column slices have no dimension 1.)
                e = torch.cat([edges_pos, edges_back_pos[:, [1, 0]]])
                e = torch.cat([
                    torch.ones(len(e), 1, dtype=torch.long) * i, e], dim=1)
                edges.append(e)
                targets.append(torch.ones(len(e)))

                e = torch.cat([edges_neg, edges_back_neg[:, [1, 0]]])
                e = torch.cat([
                    torch.ones(len(e), 1, dtype=torch.long) * i, e], dim=1)
                edges.append(e)
                targets.append(torch.zeros(len(e)))

            self.edges.append(torch.cat(edges))
            self.targets.append(torch.cat(targets))

        if self.shuffle:
            self.shuffle_families()

    def shuffle_families(self):
        """Apply one shared random permutation per family to its edges and
        targets, keeping the two aligned row-for-row."""
        for i in range(len(self.edges)):
            edges = self.edges[i]
            targets = self.targets[i]
            order = torch.randperm(len(edges), generator=self.generator)
            self.edges[i] = edges[order]
            self.targets[i] = targets[order]

    def __iter__(self):
        """Yield (family_index, edges, targets) batches, picking a random
        not-yet-exhausted family for each batch, until all are exhausted."""
        offsets = [0 for _ in self.edges]
        while True:
            choice = [i for i in range(len(offsets))
                if offsets[i] < len(self.edges[i])]
            if len(choice) == 0:
                break
            # BUGFIX: map the random draw back through `choice`. The original
            # used the position within `choice` as the family index, which
            # selects the wrong (possibly exhausted) family once any family
            # runs out of edges — looping forever on empty batches.
            fam_idx = choice[torch.randint(len(choice), (1,),
                generator=self.generator).item()]
            ofs = offsets[fam_idx]
            edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
            targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
            offsets[fam_idx] += self.batch_size
            yield (fam_idx, edges, targets)


class FastLoop(object):
    """Minimal training loop for a FastModel: one Adam step per minibatch."""

    def __init__(
        self,
        model: FastModel,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100,
        shuffle: bool = True,
        generator: torch.Generator = None) -> None:
        """Validate and store hyper-parameters, then build the optimizer.

        generator defaults to torch.default_generator when None is given.
        """
        self._check_params(model, loss, generator)
        self.model = model
        self.lr = float(lr)
        self.loss = loss
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator or torch.default_generator
        self.opt = None
        self.build()

    def _check_params(self, model, loss, generator):
        """Type-check constructor arguments; raises TypeError on mismatch."""
        if not isinstance(model, FastModel):
            raise TypeError('model must be an instance of FastModel')
        # BUGFIX: isinstance(loss, FunctionType) rejected valid callables
        # (builtins, functools.partial, nn.Module losses); accept anything
        # callable instead.
        if not callable(loss):
            raise TypeError('loss must be callable')
        if generator is not None and not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

    def build(self) -> None:
        """Create the Adam optimizer over the model's parameters."""
        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def run_epoch(self):
        """Run one full pass over the training batches; returns the summed
        (detached) loss over all batches."""
        # BUGFIX: FastBatcher requires part_type (no default); the original
        # call omitted it and raised TypeError. A training epoch uses 'train'.
        batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
            shuffle=self.shuffle, generator=self.generator,
            part_type='train')
        loss_sum = 0
        for fam_idx, edges, targets in batcher:
            self.opt.zero_grad()
            pred = self.model(None)
            # NOTE(review): edge rows are [rel_idx, row, col] (see
            # FastBatcher.build), so columns 1 and 2 look like the node
            # indices, yet this indexes columns 0 and 1 as in the original.
            # Confirm against FastModel's output layout before changing.
            scores = pred[fam_idx][edges[:, 0], edges[:, 1]]
            loss = self.loss(scores, targets)
            loss.backward()
            self.opt.step()
            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        """Train for up to max_epochs epochs.

        Returns:
            (last_loss, best_loss, best_epoch) — all None if max_epochs == 0.
        """
        # BUGFIX: initialize loss so max_epochs == 0 returns (None, None,
        # None) instead of raising UnboundLocalError.
        loss = None
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch