IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Pārlūkot izejas kodu

Start working on FastLoop.

master
Stanislaw Adaszewski pirms 4 gadiem
vecāks
revīzija
2eff854467
2 mainītis faili ar 217 papildinājumiem un 0 dzēšanām
  1. +166
    -0
      src/icosagon/fastloop.py
  2. +51
    -0
      tests/icosagon/test_fastloop.py

+ 166
- 0
src/icosagon/fastloop.py Parādīt failu

@@ -0,0 +1,166 @@
from .fastmodel import FastModel
from .trainprep import PreparedData
import torch
from typing import Callable
from types import FunctionType
import time
import random
class FastBatcher(object):
def __init__(self, prep_d: PreparedData, batch_size: int,
shuffle: bool, generator: torch.Generator,
part_type: str) -> None:
if not isinstance(prep_d, PreparedData):
raise TypeError('prep_d must be an instance of PreparedData')
if not isinstance(generator, torch.Generator):
raise TypeError('generator must be an instance of torch.Generator')
if part_type not in ['train', 'val', 'test']:
raise ValueError('part_type must be set to train, val or test')
self.prep_d = prep_d
self.batch_size = int(batch_size)
self.shuffle = bool(shuffle)
self.generator = generator
self.part_type = part_type
self.edges = None
self.targets = None
self.build()
def build(self):
self.edges = []
self.targets = []
for fam in self.prep_d.relation_families:
edges = []
targets = []
for i, rel in enumerate(fam.relation_types):
edges_pos = getattr(rel.edges_pos, self.part_type)
edges_neg = getattr(rel.edges_neg, self.part_type)
edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
edges_back_neg = getattr(rel.edges_back_neg, self.part_type)
e = torch.cat([ edges_pos,
torch.cat([edges_back_pos[:, 1], edges_back_pos[:, 0]], dim=1) ])
e = torch.cat([torch.ones(len(e), 1, dtype=torch.long) * i , e ], dim=1)
t = torch.ones(len(e))
edges.append(e)
targets.append(t)
e = torch.cat([ edges_neg,
torch.cat([edges_back_neg[:, 1], edges_back_neg[:, 0]], dim=1) ])
e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
t = torch.zeros(len(e))
edges.append(e)
targets.append(t)
edges = torch.cat(edges)
targets = torch.cat(targets)
self.edges.append(edges)
self.targets.append(targets)
# print(self.edges)
# print(self.targets)
if self.shuffle:
self.shuffle_families()
def shuffle_families(self):
for i in range(len(self.edges)):
edges = self.edges[i]
targets = self.targets[i]
order = torch.randperm(len(edges), generator=self.generator)
self.edges[i] = edges[order]
self.targets[i] = targets[order]
def __iter__(self):
offsets = [ 0 for _ in self.edges ]
while True:
choice = [ i for i in range(len(offsets)) \
if offsets[i] < len(self.edges[i]) ]
if len(choice) == 0:
break
fam_idx = torch.randint(len(choice), (1,), generator=self.generator).item()
ofs = offsets[fam_idx]
edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
offsets[fam_idx] += self.batch_size
yield (fam_idx, edges, targets)
class FastLoop(object):
def __init__(
self,
model: FastModel,
lr: float = 0.001,
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
torch.nn.functional.binary_cross_entropy_with_logits,
batch_size: int = 100,
shuffle: bool = True,
generator: torch.Generator = None) -> None:
self._check_params(model, loss, generator)
self.model = model
self.lr = float(lr)
self.loss = loss
self.batch_size = int(batch_size)
self.shuffle = bool(shuffle)
self.generator = generator or torch.default_generator
self.opt = None
self.build()
def _check_params(self, model, loss, generator):
if not isinstance(model, FastModel):
raise TypeError('model must be an instance of FastModel')
if not isinstance(loss, FunctionType):
raise TypeError('loss must be a function')
if generator is not None and not isinstance(generator, torch.Generator):
raise TypeError('generator must be an instance of torch.Generator')
def build(self) -> None:
opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.opt = opt
def run_epoch(self):
prep_d = self.model.prep_d
batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
shuffle = self.shuffle, generator=self.generator)
# pred = self.model(None)
# n = len(list(iter(batch)))
loss_sum = 0
for fam_idx, edges, targets in batcher:
self.opt.zero_grad()
pred = self.model(None)
# process pred, get input and targets
input = pred[fam_idx][edges[:, 0], edges[:, 1]]
loss = self.loss(input, targets)
loss.backward()
self.opt.step()
loss_sum += loss.detach().cpu().item()
return loss_sum
def train(self, max_epochs):
best_loss = None
best_epoch = None
for i in range(max_epochs):
loss = self.run_epoch()
if best_loss is None or loss < best_loss:
best_loss = loss
best_epoch = i
return loss, best_loss, best_epoch

+ 51
- 0
tests/icosagon/test_fastloop.py Parādīt failu

@@ -0,0 +1,51 @@
from icosagon.fastloop import FastBatcher, \
FastModel
from icosagon.data import Data
from icosagon.trainprep import prepare_training, \
TrainValTest
import torch
def test_fast_batcher_01():
d = Data()
d.add_node_type('Gene', 5)
d.add_node_type('Drug', 3)
fam = d.add_relation_family('Gene-Drug', 0, 1, True)
adj_mat = torch.tensor([
[ 1, 0, 1 ],
[ 0, 0, 1 ],
[ 0, 1, 0 ],
[ 1, 0, 0 ],
[ 0, 1, 1 ]
], dtype=torch.float32).to_sparse()
fam.add_relation_type('Target', adj_mat)
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
# print(prep_d.relation_families[0])
g = torch.Generator()
batcher = FastBatcher(prep_d, batch_size=3, shuffle=True,
generator=g, part_type='train')
print(batcher.edges)
print(batcher.targets)
edges_check = [ set() for _ in range(len(batcher.edges)) ]
for fam_idx, edges, targets in batcher:
print(fam_idx, edges, targets)
for e in edges:
edges_check[fam_idx].add(tuple(e.tolist()))
edges_check_2 = [ set() for _ in range(len(batcher.edges)) ]
for i, edges in enumerate(batcher.edges):
for e in edges:
edges_check_2[i].add(tuple(e.tolist()))
assert edges_check == edges_check_2
def test_fast_model_01():
raise NotImplementedError

Notiek ielāde…
Atcelt
Saglabāt