|
@@ -0,0 +1,166 @@ |
|
|
|
|
|
from .fastmodel import FastModel
|
|
|
|
|
|
from .trainprep import PreparedData
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from typing import Callable
|
|
|
|
|
|
from types import FunctionType
|
|
|
|
|
|
import time
|
|
|
|
|
|
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastBatcher(object):
    """Serve mini-batches of labelled edges, one relation family per batch.

    For every relation family in ``prep_d.relation_families`` the batcher
    collects the positive and negative edges of the selected partition
    (``train``, ``val`` or ``test``) of every relation type into a single
    edge tensor plus a matching target tensor (1 for positive edges, 0 for
    negative ones). Iterating the batcher yields ``(fam_idx, edges, targets)``
    tuples until every family has been consumed.
    """

    def __init__(self, prep_d: PreparedData, batch_size: int,
        shuffle: bool, generator: torch.Generator,
        part_type: str) -> None:
        """Validate the arguments and build the per-family tensors.

        Args:
            prep_d: Prepared data whose relation families are batched.
            batch_size: Maximum number of edges per yielded batch.
            shuffle: Whether to permute the edges of each family once,
                at build time.
            generator: Random generator used both for shuffling and for
                picking the family each batch is drawn from.
            part_type: Which partition to read: 'train', 'val' or 'test'.

        Raises:
            TypeError: If ``prep_d`` or ``generator`` has the wrong type.
            ValueError: If ``part_type`` is not a known partition or
                ``batch_size`` is not positive.
        """
        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

        if part_type not in ['train', 'val', 'test']:
            raise ValueError('part_type must be set to train, val or test')

        batch_size = int(batch_size)
        if batch_size <= 0:
            # A non-positive batch size would never advance the offsets in
            # __iter__ and iteration would not terminate.
            raise ValueError('batch_size must be a positive integer')

        self.prep_d = prep_d
        self.batch_size = batch_size
        self.shuffle = bool(shuffle)
        self.generator = generator
        self.part_type = part_type

        # Per-family edge and target tensors, filled in by build().
        self.edges = None
        self.targets = None

        self.build()

    @staticmethod
    def _labelled_edges(rel_idx, forward, backward, label):
        """Return ``(edges, targets)`` for one relation type and one label.

        Assumes ``forward`` and ``backward`` are 2-D integer tensors with two
        columns each -- TODO confirm against PreparedData.
        """
        # Swap the two columns of the backward edges before stacking them
        # under the forward edges. Fancy indexing keeps the result 2-D;
        # concatenating the two 1-D column slices along dim=1 is an error.
        pairs = torch.cat([forward, backward[:, [1, 0]]])
        # Prepend a column holding the relation's index within its family.
        rel_column = torch.full((len(pairs), 1), rel_idx, dtype=torch.long)
        edges = torch.cat([rel_column, pairs], dim=1)
        targets = torch.full((len(edges),), float(label))
        return edges, targets

    def build(self) -> None:
        """(Re)build ``self.edges`` and ``self.targets``, one entry per family.

        Each row of an edge tensor is ``(relation index, edge[0], edge[1])``.
        Within a relation type the positive edges come first, followed by
        the negative ones. If ``self.shuffle`` is set the rows of every
        family are permuted afterwards.
        """
        self.edges = []
        self.targets = []

        for fam in self.prep_d.relation_families:
            fam_edges = []
            fam_targets = []

            for rel_idx, rel in enumerate(fam.relation_types):
                edges_pos = getattr(rel.edges_pos, self.part_type)
                edges_neg = getattr(rel.edges_neg, self.part_type)
                edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
                edges_back_neg = getattr(rel.edges_back_neg, self.part_type)

                for forward, backward, label in (
                        (edges_pos, edges_back_pos, 1),
                        (edges_neg, edges_back_neg, 0)):
                    e, t = self._labelled_edges(rel_idx, forward,
                        backward, label)
                    fam_edges.append(e)
                    fam_targets.append(t)

            self.edges.append(torch.cat(fam_edges))
            self.targets.append(torch.cat(fam_targets))

        if self.shuffle:
            self.shuffle_families()

    def shuffle_families(self) -> None:
        """Permute the rows of every family, keeping edges and targets aligned."""
        for i, (edges, targets) in enumerate(zip(self.edges, self.targets)):
            order = torch.randperm(len(edges), generator=self.generator)
            self.edges[i] = edges[order]
            self.targets[i] = targets[order]

    def __iter__(self):
        """Yield ``(fam_idx, edges, targets)`` batches until all data is served.

        For every batch a family is drawn uniformly at random (using
        ``self.generator``) among the families that still have unserved
        edges; the next ``batch_size`` rows of that family are yielded.
        """
        offsets = [0] * len(self.edges)

        while True:
            # Families that still have rows left to serve.
            remaining = [i for i, ofs in enumerate(offsets)
                if ofs < len(self.edges[i])]
            if not remaining:
                break

            # The random draw is a position within `remaining`; it must be
            # mapped back to the real family index, otherwise exhausted
            # families get picked and others may never be reached.
            pick = torch.randint(len(remaining), (1,),
                generator=self.generator).item()
            fam_idx = remaining[pick]

            ofs = offsets[fam_idx]
            edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
            targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
            offsets[fam_idx] += self.batch_size
            yield (fam_idx, edges, targets)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastLoop(object):
    """Minimal training loop for a ``FastModel`` using the Adam optimizer."""

    def __init__(
        self,
        model: FastModel,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100,
        shuffle: bool = True,
        generator: torch.Generator = None) -> None:
        """Validate the arguments, store them and create the optimizer.

        Args:
            model: The model to train.
            lr: Learning rate handed to the optimizer.
            loss: Callable taking ``(input, targets)`` and returning a
                scalar loss tensor.
            batch_size: Number of edges per batch.
            shuffle: Whether the batcher shuffles edges within families.
            generator: Random generator for the batcher; the torch default
                generator is used when omitted.

        Raises:
            TypeError: If ``model``, ``loss`` or ``generator`` has the
                wrong type.
        """
        self._check_params(model, loss, generator)

        self.model = model
        self.lr = float(lr)
        self.loss = loss
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator if generator is not None \
            else torch.default_generator

        self.opt = None

        self.build()

    def _check_params(self, model, loss, generator):
        """Raise ``TypeError`` if any constructor argument has the wrong type."""
        if not isinstance(model, FastModel):
            raise TypeError('model must be an instance of FastModel')

        # Accept any callable (plain functions, nn.Module losses,
        # functools.partial objects, builtins), as the Callable annotation
        # on the constructor promises.
        if not callable(loss):
            raise TypeError('loss must be a function')

        if generator is not None and not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

    def build(self) -> None:
        """Create the Adam optimizer over the model's parameters."""
        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def run_epoch(self):
        """Run one pass over the training partition and return the summed loss.

        A fresh ``FastBatcher`` is created for every epoch; for each batch
        the model is evaluated, the loss computed on the selected
        predictions and one optimizer step taken.

        Returns:
            The sum of the per-batch loss values as a Python number.
        """
        batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
            shuffle=self.shuffle, generator=self.generator,
            part_type='train')

        loss_sum = 0
        for fam_idx, edges, targets in batcher:
            self.opt.zero_grad()
            pred = self.model(None)

            # Select the predictions that correspond to this batch's edges.
            # NOTE(review): FastBatcher.build emits three columns per edge
            # (relation index followed by the two edge columns) while only
            # columns 0 and 1 are used here -- confirm this matches the
            # layout of FastModel's output.
            scores = pred[fam_idx][edges[:, 0], edges[:, 1]]

            loss = self.loss(scores, targets)
            loss.backward()
            self.opt.step()
            loss_sum += loss.item()
        return loss_sum

    def train(self, max_epochs):
        """Run up to ``max_epochs`` epochs and track the best epoch loss.

        Returns:
            A tuple ``(last_loss, best_loss, best_epoch)``. All three are
            ``None`` when ``max_epochs`` is not positive and no epoch ran.
        """
        # Initialised up front so the return below is valid even when the
        # loop body never executes.
        loss = None
        best_loss = None
        best_epoch = None
        for epoch in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = epoch
        return loss, best_loss, best_epoch
|