IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add triacontagon.TrainLoop.

master
Stanislaw Adaszewski 3 years ago
parent
commit
4c8c06c63c
2 changed files with 96 additions and 30 deletions
  1. +90
    -25
      src/triacontagon/loop.py
  2. +6
    -5
      src/triacontagon/model.py

+ 90
- 25
src/triacontagon/loop.py View File

@@ -1,40 +1,105 @@
from .model import Model
from .batch import Batcher
from .sampling import negative_sample_data
from .data import Data
class TrainLoop(object):
def __init__(self, model: Model,
pos_batcher: Batcher,
neg_batcher: Batcher,
max_epochs: int = 50) -> None:
def _merge_pos_neg_batches(pos_batch, neg_batch):
assert len(pos_batch.edges) == len(neg_batch.edges)
assert pos_batch.vertex_type_row == neg_batch.vertex_type_row
assert pos_batch.vertex_type_column == neg_batch.vertex_type_column
assert pos_batch.relation_type_index == neg_batch.relation_type_index
if not isinstance(model, Model):
raise TypeError('model must be an instance of Model')
batch = TrainingBatch(pos_batch.vertex_type_row,
pos_batch.vertex_type_column,
pos_batch.relation_type_index,
torch.cat([ pos_batch.edges, neg_batch.edges ]),
torch.cat([
torch.ones(len(pos_batch.edges)),
torch.zeros(len(neg_batch.edges))
]))
return batch
if not isinstance(pos_batcher, Batcher):
raise TypeError('pos_batcher must be an instance of Batcher')
if not isinstance(neg_batcher, Batcher):
raise TypeError('neg_batcher must be an instance of Batcher')
class TrainLoop(object):
def __init__(self, model: Model,
val_data: Data, test_data: Data,
initial_repr: List[torch.Tensor],
max_epochs: int = 50,
batch_size: int = 512,
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
torch.nn.functional.binary_cross_entropy_with_logits,
lr: float = 0.001) -> None:
assert isinstance(model, Model)
assert isinstance(val_data, Data)
assert isinstance(test_data, Data)
assert callable(loss)
self.model = model
self.pos_batcher = pos_batcher
self.neg_batcher = neg_batcher
self.test_data = test_data
self.initial_repr = list(initial_repr)
self.max_epochs = int(num_epochs)
self.batch_size = int(batch_size)
self.loss = loss
self.lr = float(lr)
self.pos_data = model.data
self.neg_data = negative_sample_data(model.data)
self.pos_val_data = val_data
self.neg_val_data = negative_sample_data(val_data)
self.batcher = DualBatcher(self.pos_data, self.neg_data,
batch_size=batch_size)
self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data)
self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
def run_epoch(self) -> None:
pos_it = iter(self.pos_batcher)
neg_it = iter(self.neg_batcher)
while True:
try:
pos_batch = next(pos_it)
neg_batch = next(neg_it)
except StopIteration:
break
if len(pos_batch.edges) != len(neg_batch.edges):
raise ValueError('Positive and negative batch should have same length')
loss_sum = 0.
for pos_batch, neg_batch in self.batcher:
batch = _merge_pos_neg_batches(pos_batch, neg_batch)
self.opt.zero_grad()
last_layer_repr = self.model.convolve(self.initial_repr)
pred = self.model.decode(last_layer_repr, batch)
loss = self.loss(pred, batch.target_values)
loss.backward()
self.opt.step()
loss = loss.detach().cpu().item()
loss_sum += loss
print('loss:', loss)
return loss_sum
def validate_epoch(self):
loss_sum = 0.
for pos_batch, neg_batch in self.val_batcher:
batch = _merge_pos_neg_batches(pos_batch, neg_batch)
with torch.no_grad():
last_layer_repr = self.model.convolve(self.initial_repr, eval_mode=True)
pred = self.model.decode(last_layer_repr, batch, eval_mode=True)
loss = self.loss(pred, batch.target_values)
loss = loss.detach().cpu().item()
loss_sum += loss
return loss_sum
def run(self) -> None:
best_loss = float('inf')
epochs_without_improvement = 0
for epoch in range(self.max_epochs):
self.run_epoch()
print('Epoch', epoch)
loss_sum = self.run_epoch()
print('train loss_sum:', loss_sum)
loss_sum = self.validate_epoch()
print('val loss_sum:', loss_sum)
if loss_sum >= best_loss:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
best_loss = loss_sum
if epochs_without_improvement == 2:
print('Early stopping after epoch', epoch, 'due to no improvement')
return (epoch, best_loss, loss_sum)

+ 6
- 5
src/triacontagon/model.py View File

@@ -130,7 +130,7 @@ class Model(torch.nn.Module):
torch.nn.Parameter(local_variation[i])
def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
def convolve(self, in_layer_repr: List[torch.Tensor], eval_mode=False) -> \
List[torch.Tensor]:
cur_layer_repr = in_layer_repr
@@ -145,7 +145,7 @@ class Model(torch.nn.Module):
num_relation_types = len(et.adjacency_matrices)
x = cur_layer_repr[vt_col]
if self.keep_prob != 1:
if self.keep_prob != 1 and not eval_mode:
x = dropout(x, self.keep_prob)
# print('a, Layer:', i, 'x.shape:', x.shape)
@@ -176,7 +176,7 @@ class Model(torch.nn.Module):
return cur_layer_repr
def decode(self, last_layer_repr: List[torch.Tensor],
batch: TrainingBatch) -> torch.Tensor:
batch: TrainingBatch, eval_mode=False) -> torch.Tensor:
vt_row = batch.vertex_type_row
vt_col = batch.vertex_type_column
@@ -195,8 +195,9 @@ class Model(torch.nn.Module):
in_row = in_row[batch.edges[:, 0]]
in_col = in_col[batch.edges[:, 1]]
in_row = dropout(in_row, self.keep_prob)
in_col = dropout(in_col, self.keep_prob)
if self.keep_prob != 1 and not eval_mode:
in_row = dropout(in_row, self.keep_prob)
in_col = dropout(in_col, self.keep_prob)
# in_row = in_row.to_dense()
# in_col = in_col.to_dense()


Loading…
Cancel
Save