IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

133 lines
4.5KB

  1. from .data import Data
  2. from typing import List, \
  3. Callable
  4. from .trainprep import prepare_training, \
  5. TrainValTest
  6. import torch
  7. from .convlayer import DecagonLayer
  8. from .input import OneHotInputLayer
  9. from types import FunctionType
  10. from .declayer import DecodeLayer
  11. from .batch import PredictionsBatch
  12. class Model(object):
  13. def __init__(self, data: Data,
  14. layer_dimensions: List[int] = [32, 64],
  15. ratios: TrainValTest = TrainValTest(.8, .1, .1),
  16. keep_prob: float = 1.,
  17. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  18. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  19. dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  20. lr: float = 0.001,
  21. loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
  22. batch_size: int = 100) -> None:
  23. if not isinstance(data, Data):
  24. raise TypeError('data must be an instance of Data')
  25. if not isinstance(layer_dimensions, list):
  26. raise TypeError('layer_dimensions must be a list')
  27. if not isinstance(ratios, TrainValTest):
  28. raise TypeError('ratios must be an instance of TrainValTest')
  29. keep_prob = float(keep_prob)
  30. if not isinstance(rel_activation, FunctionType):
  31. raise TypeError('rel_activation must be a function')
  32. if not isinstance(layer_activation, FunctionType):
  33. raise TypeError('layer_activation must be a function')
  34. if not isinstance(dec_activation, FunctionType):
  35. raise TypeError('dec_activation must be a function')
  36. lr = float(lr)
  37. if not isinstance(loss, FunctionType):
  38. raise TypeError('loss must be a function')
  39. batch_size = int(batch_size)
  40. self.data = data
  41. self.layer_dimensions = layer_dimensions
  42. self.ratios = ratios
  43. self.keep_prob = keep_prob
  44. self.rel_activation = rel_activation
  45. self.layer_activation = layer_activation
  46. self.dec_activation = dec_activation
  47. self.lr = lr
  48. self.loss = loss
  49. self.batch_size = batch_size
  50. self.prep_d = None
  51. self.seq = None
  52. self.opt = None
  53. self.build()
  54. def build(self):
  55. self.prep_d = prepare_training(self.data, self.ratios)
  56. in_layer = OneHotInputLayer(self.prep_d)
  57. last_output_dim = in_layer.output_dim
  58. seq = [ in_layer ]
  59. for dim in self.layer_dimensions:
  60. conv_layer = DecagonLayer(input_dim = last_output_dim,
  61. output_dim = [ dim ] * len(self.prep_d.node_types),
  62. data = self.prep_d,
  63. keep_prob = self.keep_prob,
  64. rel_activation = self.rel_activation,
  65. layer_activation = self.layer_activation)
  66. last_output_dim = conv_layer.output_dim
  67. seq.append(conv_layer)
  68. dec_layer = DecodeLayer(input_dim = last_output_dim,
  69. data = self.prep_d,
  70. keep_prob = self.keep_prob,
  71. activation = self.dec_activation)
  72. seq.append(dec_layer)
  73. seq = torch.nn.Sequential(*seq)
  74. self.seq = seq
  75. opt = torch.optim.Adam(seq.parameters(), lr=self.lr)
  76. self.opt = opt
  77. def run_epoch(self):
  78. pred = self.seq(None)
  79. batch = PredictionsBatch(pred, batch_size=self.batch_size)
  80. n = len(list(iter(batch)))
  81. loss_sum = 0
  82. for i in range(n):
  83. self.opt.zero_grad()
  84. pred = self.seq(None)
  85. batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
  86. seed = torch.rand(1).item()
  87. rng_state = torch.get_rng_state()
  88. torch.manual_seed(seed)
  89. it = iter(batch)
  90. torch.set_rng_state(rng_state)
  91. for k in range(i):
  92. _ = next(it)
  93. (input, target) = next(it)
  94. loss = self.loss(input, target)
  95. loss.backward()
  96. self.opt.step()
  97. loss_sum += loss.detach().cpu().item()
  98. return loss_sum
  99. def train(self, max_epochs):
  100. best_loss = None
  101. best_epoch = None
  102. for i in range(max_epochs):
  103. loss = self.run_epoch()
  104. if best_loss is None or loss < best_loss:
  105. best_loss = loss
  106. best_epoch = i
  107. return loss, best_loss, best_epoch