diff --git a/src/icosagon/fastloop.py b/src/icosagon/fastloop.py new file mode 100644 index 0000000..f955932 --- /dev/null +++ b/src/icosagon/fastloop.py @@ -0,0 +1,166 @@ +from .fastmodel import FastModel +from .trainprep import PreparedData +import torch +from typing import Callable +from types import FunctionType +import time +import random + + +class FastBatcher(object): + def __init__(self, prep_d: PreparedData, batch_size: int, + shuffle: bool, generator: torch.Generator, + part_type: str) -> None: + + if not isinstance(prep_d, PreparedData): + raise TypeError('prep_d must be an instance of PreparedData') + + if not isinstance(generator, torch.Generator): + raise TypeError('generator must be an instance of torch.Generator') + + if part_type not in ['train', 'val', 'test']: + raise ValueError('part_type must be set to train, val or test') + + self.prep_d = prep_d + self.batch_size = int(batch_size) + self.shuffle = bool(shuffle) + self.generator = generator + self.part_type = part_type + + self.edges = None + self.targets = None + self.build() + + def build(self): + self.edges = [] + self.targets = [] + + for fam in self.prep_d.relation_families: + edges = [] + targets = [] + for i, rel in enumerate(fam.relation_types): + + edges_pos = getattr(rel.edges_pos, self.part_type) + edges_neg = getattr(rel.edges_neg, self.part_type) + edges_back_pos = getattr(rel.edges_back_pos, self.part_type) + edges_back_neg = getattr(rel.edges_back_neg, self.part_type) + + e = torch.cat([ edges_pos, + torch.cat([edges_back_pos[:, 1], edges_back_pos[:, 0]], dim=1) ]) + e = torch.cat([torch.ones(len(e), 1, dtype=torch.long) * i , e ], dim=1) + t = torch.ones(len(e)) + edges.append(e) + targets.append(t) + + e = torch.cat([ edges_neg, + torch.cat([edges_back_neg[:, 1], edges_back_neg[:, 0]], dim=1) ]) + e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1) + t = torch.zeros(len(e)) + edges.append(e) + targets.append(t) + + edges = torch.cat(edges) + targets = torch.cat(targets) + + self.edges.append(edges) + self.targets.append(targets) + + # print(self.edges) + # print(self.targets) + + if self.shuffle: + self.shuffle_families() + + def shuffle_families(self): + for i in range(len(self.edges)): + edges = self.edges[i] + targets = self.targets[i] + order = torch.randperm(len(edges), generator=self.generator) + self.edges[i] = edges[order] + self.targets[i] = targets[order] + + def __iter__(self): + offsets = [ 0 for _ in self.edges ] + + while True: + choice = [ i for i in range(len(offsets)) \ + if offsets[i] < len(self.edges[i]) ] + if len(choice) == 0: + break + fam_idx = torch.randint(len(choice), (1,), generator=self.generator).item() + ofs = offsets[fam_idx] + edges = self.edges[fam_idx][ofs:ofs + self.batch_size] + targets = self.targets[fam_idx][ofs:ofs + self.batch_size] + offsets[fam_idx] += self.batch_size + yield (fam_idx, edges, targets) + + +class FastLoop(object): + def __init__( + self, + model: FastModel, + lr: float = 0.001, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ + torch.nn.functional.binary_cross_entropy_with_logits, + batch_size: int = 100, + shuffle: bool = True, + generator: torch.Generator = None) -> None: + + self._check_params(model, loss, generator) + + self.model = model + self.lr = float(lr) + self.loss = loss + self.batch_size = int(batch_size) + self.shuffle = bool(shuffle) + self.generator = generator or torch.default_generator + + self.opt = None + + self.build() + + def _check_params(self, model, loss, generator): + if not isinstance(model, FastModel): + raise TypeError('model must be an instance of FastModel') + + if not isinstance(loss, FunctionType): + raise TypeError('loss must be a function') + + if generator is not None and not isinstance(generator, torch.Generator): + raise TypeError('generator must be an instance of torch.Generator') + + def build(self) -> None: + opt = torch.optim.Adam(self.model.parameters(), lr=self.lr) + self.opt = opt + + def run_epoch(self): + prep_d = self.model.prep_d + + batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size, + shuffle = self.shuffle, generator=self.generator) + # pred = self.model(None) + # n = len(list(iter(batch))) + loss_sum = 0 + for fam_idx, edges, targets in batcher: + self.opt.zero_grad() + pred = self.model(None) + + # process pred, get input and targets + input = pred[fam_idx][edges[:, 0], edges[:, 1]] + + loss = self.loss(input, targets) + loss.backward() + self.opt.step() + loss_sum += loss.detach().cpu().item() + return loss_sum + + + def train(self, max_epochs): + best_loss = None + best_epoch = None + for i in range(max_epochs): + loss = self.run_epoch() + if best_loss is None or loss < best_loss: + best_loss = loss + best_epoch = i + return loss, best_loss, best_epoch diff --git a/tests/icosagon/test_fastloop.py b/tests/icosagon/test_fastloop.py new file mode 100644 index 0000000..0afb285 --- /dev/null +++ b/tests/icosagon/test_fastloop.py @@ -0,0 +1,51 @@ +from icosagon.fastloop import FastBatcher, \ + FastModel +from icosagon.data import Data +from icosagon.trainprep import prepare_training, \ + TrainValTest +import torch + + +def test_fast_batcher_01(): + d = Data() + d.add_node_type('Gene', 5) + d.add_node_type('Drug', 3) + + fam = d.add_relation_family('Gene-Drug', 0, 1, True) + + adj_mat = torch.tensor([ + [ 1, 0, 1 ], + [ 0, 0, 1 ], + [ 0, 1, 0 ], + [ 1, 0, 0 ], + [ 0, 1, 1 ] + ], dtype=torch.float32).to_sparse() + fam.add_relation_type('Target', adj_mat) + + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) + # print(prep_d.relation_families[0]) + + g = torch.Generator() + batcher = FastBatcher(prep_d, batch_size=3, shuffle=True, + generator=g, part_type='train') + + print(batcher.edges) + print(batcher.targets) + + edges_check = [ set() for _ in range(len(batcher.edges)) ] + + for fam_idx, edges, targets in batcher: + print(fam_idx, edges, targets) + for e in edges: + edges_check[fam_idx].add(tuple(e.tolist())) + + edges_check_2 = [ set() for _ in range(len(batcher.edges)) ] + for i, edges in enumerate(batcher.edges): + for e in edges: + edges_check_2[i].add(tuple(e.tolist())) + + assert edges_check == edges_check_2 + + +def test_fast_model_01(): + raise NotImplementedError