IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

106 lines
3.8KB

from typing import Callable, List

import torch

from .model import Model
from .batch import Batcher, DualBatcher
from .sampling import negative_sample_data
from .data import Data
# NOTE(review): TrainingBatch is referenced below but its home module is not
# visible in this file — TODO confirm and import it (likely .batch or .data).
  5. def _merge_pos_neg_batches(pos_batch, neg_batch):
  6. assert len(pos_batch.edges) == len(neg_batch.edges)
  7. assert pos_batch.vertex_type_row == neg_batch.vertex_type_row
  8. assert pos_batch.vertex_type_column == neg_batch.vertex_type_column
  9. assert pos_batch.relation_type_index == neg_batch.relation_type_index
  10. batch = TrainingBatch(pos_batch.vertex_type_row,
  11. pos_batch.vertex_type_column,
  12. pos_batch.relation_type_index,
  13. torch.cat([ pos_batch.edges, neg_batch.edges ]),
  14. torch.cat([
  15. torch.ones(len(pos_batch.edges)),
  16. torch.zeros(len(neg_batch.edges))
  17. ]))
  18. return batch
  19. class TrainLoop(object):
  20. def __init__(self, model: Model,
  21. val_data: Data, test_data: Data,
  22. initial_repr: List[torch.Tensor],
  23. max_epochs: int = 50,
  24. batch_size: int = 512,
  25. loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
  26. torch.nn.functional.binary_cross_entropy_with_logits,
  27. lr: float = 0.001) -> None:
  28. assert isinstance(model, Model)
  29. assert isinstance(val_data, Data)
  30. assert isinstance(test_data, Data)
  31. assert callable(loss)
  32. self.model = model
  33. self.test_data = test_data
  34. self.initial_repr = list(initial_repr)
  35. self.max_epochs = int(num_epochs)
  36. self.batch_size = int(batch_size)
  37. self.loss = loss
  38. self.lr = float(lr)
  39. self.pos_data = model.data
  40. self.neg_data = negative_sample_data(model.data)
  41. self.pos_val_data = val_data
  42. self.neg_val_data = negative_sample_data(val_data)
  43. self.batcher = DualBatcher(self.pos_data, self.neg_data,
  44. batch_size=batch_size)
  45. self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data)
  46. self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
  47. def run_epoch(self) -> None:
  48. loss_sum = 0.
  49. for pos_batch, neg_batch in self.batcher:
  50. batch = _merge_pos_neg_batches(pos_batch, neg_batch)
  51. self.opt.zero_grad()
  52. last_layer_repr = self.model.convolve(self.initial_repr)
  53. pred = self.model.decode(last_layer_repr, batch)
  54. loss = self.loss(pred, batch.target_values)
  55. loss.backward()
  56. self.opt.step()
  57. loss = loss.detach().cpu().item()
  58. loss_sum += loss
  59. print('loss:', loss)
  60. return loss_sum
  61. def validate_epoch(self):
  62. loss_sum = 0.
  63. for pos_batch, neg_batch in self.val_batcher:
  64. batch = _merge_pos_neg_batches(pos_batch, neg_batch)
  65. with torch.no_grad():
  66. last_layer_repr = self.model.convolve(self.initial_repr, eval_mode=True)
  67. pred = self.model.decode(last_layer_repr, batch, eval_mode=True)
  68. loss = self.loss(pred, batch.target_values)
  69. loss = loss.detach().cpu().item()
  70. loss_sum += loss
  71. return loss_sum
  72. def run(self) -> None:
  73. best_loss = float('inf')
  74. epochs_without_improvement = 0
  75. for epoch in range(self.max_epochs):
  76. print('Epoch', epoch)
  77. loss_sum = self.run_epoch()
  78. print('train loss_sum:', loss_sum)
  79. loss_sum = self.validate_epoch()
  80. print('val loss_sum:', loss_sum)
  81. if loss_sum >= best_loss:
  82. epochs_without_improvement += 1
  83. else:
  84. epochs_without_improvement = 0
  85. best_loss = loss_sum
  86. if epochs_without_improvement == 2:
  87. print('Early stopping after epoch', epoch, 'due to no improvement')
  88. return (epoch, best_loss, loss_sum)