from .fastmodel import FastModel
from .trainprep import PreparedData
import torch
from typing import Callable, Optional


class FastBatcher(object):
    def __init__(self, prep_d: PreparedData, batch_size: int,
        shuffle: bool, generator: torch.Generator,
        part_type: str) -> None:

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

        if part_type not in ['train', 'val', 'test']:
            raise ValueError('part_type must be set to train, val or test')

        self.prep_d = prep_d
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator
        self.part_type = part_type

        self.edges = None
        self.targets = None
        self.build()

    def build(self):
        self.edges = []
        self.targets = []

        for fam in self.prep_d.relation_families:
            edges = []
            targets = []
            for i, rel in enumerate(fam.relation_types):

                edges_pos = getattr(rel.edges_pos, self.part_type)
                edges_neg = getattr(rel.edges_neg, self.part_type)
                edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
                edges_back_neg = getattr(rel.edges_back_neg, self.part_type)

                # Positive edges: forward edges plus back edges with their
                # (row, column) order flipped, each row prefixed with the
                # relation type index within the family.
                e = torch.cat([ edges_pos,
                    edges_back_pos[:, [1, 0]] ])
                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
                t = torch.ones(len(e))
                edges.append(e)
                targets.append(t)

                # Negative edges, built the same way, with target 0.
                e = torch.cat([ edges_neg,
                    edges_back_neg[:, [1, 0]] ])
                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
                t = torch.zeros(len(e))
                edges.append(e)
                targets.append(t)

            edges = torch.cat(edges)
            targets = torch.cat(targets)

            self.edges.append(edges)
            self.targets.append(targets)

        if self.shuffle:
            self.shuffle_families()

    def shuffle_families(self):
        # Shuffle each family's edges and targets with the same permutation,
        # so that every edge stays aligned with its target.
        for i in range(len(self.edges)):
            edges = self.edges[i]
            targets = self.targets[i]
            order = torch.randperm(len(edges), generator=self.generator)
            self.edges[i] = edges[order]
            self.targets[i] = targets[order]

    def __iter__(self):
        offsets = [ 0 for _ in self.edges ]

        while True:
            # Families that still have unconsumed edges in this epoch.
            choice = [ i for i in range(len(offsets)) \
                if offsets[i] < len(self.edges[i]) ]
            if len(choice) == 0:
                break
            # Pick one of the remaining families at random; the drawn index
            # points into `choice`, not directly into self.edges.
            fam_idx = choice[torch.randint(len(choice), (1,),
                generator=self.generator).item()]
            ofs = offsets[fam_idx]
            edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
            targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
            offsets[fam_idx] += self.batch_size
            yield (fam_idx, edges, targets)

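

# Illustrative helper (hypothetical, added as a usage sketch only): iterates
# one epoch of training batches from a FastBatcher. It assumes `prep_d` is an
# already-built PreparedData instance; the batch size and seed are arbitrary.
def _example_batcher_usage(prep_d: PreparedData) -> None:
    gen = torch.Generator().manual_seed(0)
    batcher = FastBatcher(prep_d, batch_size=512, shuffle=True,
        generator=gen, part_type='train')
    for fam_idx, edges, targets in batcher:
        # edges: (B, 3) long tensor of [relation index, row, column];
        # targets: (B,) float tensor, 1 for positive and 0 for negative edges.
        print(fam_idx, edges.shape, targets.shape)
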
class FastLoop(object):
    def __init__(
        self,
        model: FastModel,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100,
        shuffle: bool = True,
        generator: Optional[torch.Generator] = None) -> None:

        self._check_params(model, loss, generator)

        self.model = model
        self.lr = float(lr)
        self.loss = loss
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator or torch.default_generator

        self.opt = None

        self.build()

    def _check_params(self, model, loss, generator):
        if not isinstance(model, FastModel):
            raise TypeError('model must be an instance of FastModel')

        if not callable(loss):
            raise TypeError('loss must be callable')

        if generator is not None and not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

    def build(self) -> None:
        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def run_epoch(self):
        # One pass over the training partition of the prepared data.
        batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
            shuffle=self.shuffle, generator=self.generator,
            part_type='train')

        loss_sum = 0
        for fam_idx, edges, targets in batcher:
            self.opt.zero_grad()
            pred = self.model(None)

            # Look up the predicted scores for this batch. This assumes the
            # model returns, per relation family, a dense score tensor indexed
            # by (relation type, row, column), matching the batch layout of
            # [relation index, row, column] produced by FastBatcher; adjust if
            # FastModel uses a different output layout.
            scores = pred[fam_idx][edges[:, 0], edges[:, 1], edges[:, 2]]

            loss = self.loss(scores, targets)
            loss.backward()
            self.opt.step()
            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        loss = None
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch
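

# Illustrative helper (hypothetical, added as a usage sketch only): trains a
# FastModel for a fixed number of epochs and reports the best epoch. It
# assumes `model` already wraps a PreparedData instance; the learning rate,
# batch size, seed and epoch count are arbitrary.
def _example_training_run(model: FastModel) -> None:
    loop = FastLoop(model, lr=0.001, batch_size=512, shuffle=True,
        generator=torch.Generator().manual_seed(0))
    last_loss, best_loss, best_epoch = loop.train(max_epochs=50)
    print('last loss:', last_loss)
    print('best loss:', best_loss, 'at epoch:', best_epoch)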