IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
3.9KB

from typing import Callable, List

import torch

from .batch import Batcher, DualBatcher
from .data import Data
from .model import Model, TrainingBatch
from .sampling import negative_sample_data
  9. def _merge_pos_neg_batches(pos_batch, neg_batch):
  10. assert len(pos_batch.edges) == len(neg_batch.edges)
  11. assert pos_batch.vertex_type_row == neg_batch.vertex_type_row
  12. assert pos_batch.vertex_type_column == neg_batch.vertex_type_column
  13. assert pos_batch.relation_type_index == neg_batch.relation_type_index
  14. batch = TrainingBatch(pos_batch.vertex_type_row,
  15. pos_batch.vertex_type_column,
  16. pos_batch.relation_type_index,
  17. torch.cat([ pos_batch.edges, neg_batch.edges ]),
  18. torch.cat([
  19. torch.ones(len(pos_batch.edges)),
  20. torch.zeros(len(neg_batch.edges))
  21. ]))
  22. return batch
  23. class TrainLoop(object):
  24. def __init__(self, model: Model,
  25. val_data: Data, test_data: Data,
  26. initial_repr: List[torch.Tensor],
  27. max_epochs: int = 50,
  28. batch_size: int = 512,
  29. loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
  30. torch.nn.functional.binary_cross_entropy_with_logits,
  31. lr: float = 0.001) -> None:
  32. assert isinstance(model, Model)
  33. assert isinstance(val_data, Data)
  34. assert isinstance(test_data, Data)
  35. assert callable(loss)
  36. self.model = model
  37. self.test_data = test_data
  38. self.initial_repr = list(initial_repr)
  39. self.max_epochs = int(num_epochs)
  40. self.batch_size = int(batch_size)
  41. self.loss = loss
  42. self.lr = float(lr)
  43. self.pos_data = model.data
  44. self.neg_data = negative_sample_data(model.data)
  45. self.pos_val_data = val_data
  46. self.neg_val_data = negative_sample_data(val_data)
  47. self.batcher = DualBatcher(self.pos_data, self.neg_data,
  48. batch_size=batch_size)
  49. self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data)
  50. self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
  51. def run_epoch(self) -> None:
  52. loss_sum = 0.
  53. for pos_batch, neg_batch in self.batcher:
  54. batch = _merge_pos_neg_batches(pos_batch, neg_batch)
  55. self.opt.zero_grad()
  56. last_layer_repr = self.model.convolve(self.initial_repr)
  57. pred = self.model.decode(last_layer_repr, batch)
  58. loss = self.loss(pred, batch.target_values)
  59. loss.backward()
  60. self.opt.step()
  61. loss = loss.detach().cpu().item()
  62. loss_sum += loss
  63. print('loss:', loss)
  64. return loss_sum
  65. def validate_epoch(self):
  66. loss_sum = 0.
  67. for pos_batch, neg_batch in self.val_batcher:
  68. batch = _merge_pos_neg_batches(pos_batch, neg_batch)
  69. with torch.no_grad():
  70. last_layer_repr = self.model.convolve(self.initial_repr, eval_mode=True)
  71. pred = self.model.decode(last_layer_repr, batch, eval_mode=True)
  72. loss = self.loss(pred, batch.target_values)
  73. loss = loss.detach().cpu().item()
  74. loss_sum += loss
  75. return loss_sum
  76. def run(self) -> None:
  77. best_loss = float('inf')
  78. epochs_without_improvement = 0
  79. for epoch in range(self.max_epochs):
  80. print('Epoch', epoch)
  81. loss_sum = self.run_epoch()
  82. print('train loss_sum:', loss_sum)
  83. loss_sum = self.validate_epoch()
  84. print('val loss_sum:', loss_sum)
  85. if loss_sum >= best_loss:
  86. epochs_without_improvement += 1
  87. else:
  88. epochs_without_improvement = 0
  89. best_loss = loss_sum
  90. if epochs_without_improvement == 2:
  91. print('Early stopping after epoch', epoch, 'due to no improvement')
  92. return (epoch, best_loss, loss_sum)