IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
4.2KB

  1. from .data import Data
  2. from typing import List, \
  3. Callable
  4. from .trainprep import PreparedData
  5. import torch
  6. from .convlayer import DecagonLayer
  7. from .input import OneHotInputLayer
  8. from types import FunctionType
  9. from .declayer import DecodeLayer
  10. from .batch import PredictionsBatch
  11. class Model(object):
  12. def __init__(self, prep_d: PreparedData,
  13. layer_dimensions: List[int] = [32, 64],
  14. keep_prob: float = 1.,
  15. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  16. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  17. dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  18. lr: float = 0.001,
  19. loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.binary_cross_entropy_with_logits,
  20. batch_size: int = 100) -> None:
  21. if not isinstance(prep_d, PreparedData):
  22. raise TypeError('prep_d must be an instance of PreparedData')
  23. if not isinstance(layer_dimensions, list):
  24. raise TypeError('layer_dimensions must be a list')
  25. keep_prob = float(keep_prob)
  26. if not isinstance(rel_activation, FunctionType):
  27. raise TypeError('rel_activation must be a function')
  28. if not isinstance(layer_activation, FunctionType):
  29. raise TypeError('layer_activation must be a function')
  30. if not isinstance(dec_activation, FunctionType):
  31. raise TypeError('dec_activation must be a function')
  32. lr = float(lr)
  33. if not isinstance(loss, FunctionType):
  34. raise TypeError('loss must be a function')
  35. batch_size = int(batch_size)
  36. self.prep_d = prep_d
  37. self.layer_dimensions = layer_dimensions
  38. self.keep_prob = keep_prob
  39. self.rel_activation = rel_activation
  40. self.layer_activation = layer_activation
  41. self.dec_activation = dec_activation
  42. self.lr = lr
  43. self.loss = loss
  44. self.batch_size = batch_size
  45. self.seq = None
  46. self.opt = None
  47. self.build()
  48. def build(self):
  49. in_layer = OneHotInputLayer(self.prep_d)
  50. last_output_dim = in_layer.output_dim
  51. seq = [ in_layer ]
  52. for dim in self.layer_dimensions:
  53. conv_layer = DecagonLayer(input_dim = last_output_dim,
  54. output_dim = [ dim ] * len(self.prep_d.node_types),
  55. data = self.prep_d,
  56. keep_prob = self.keep_prob,
  57. rel_activation = self.rel_activation,
  58. layer_activation = self.layer_activation)
  59. last_output_dim = conv_layer.output_dim
  60. seq.append(conv_layer)
  61. dec_layer = DecodeLayer(input_dim = last_output_dim,
  62. data = self.prep_d,
  63. keep_prob = self.keep_prob,
  64. activation = self.dec_activation)
  65. seq.append(dec_layer)
  66. seq = torch.nn.Sequential(*seq)
  67. self.seq = seq
  68. opt = torch.optim.Adam(seq.parameters(), lr=self.lr)
  69. self.opt = opt
  70. def run_epoch(self):
  71. pred = self.seq(None)
  72. batch = PredictionsBatch(pred, batch_size=self.batch_size)
  73. n = len(list(iter(batch)))
  74. loss_sum = 0
  75. for i in range(n):
  76. self.opt.zero_grad()
  77. pred = self.seq(None)
  78. batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
  79. seed = torch.rand(1).item()
  80. rng_state = torch.get_rng_state()
  81. torch.manual_seed(seed)
  82. it = iter(batch)
  83. torch.set_rng_state(rng_state)
  84. for k in range(i):
  85. _ = next(it)
  86. (input, target) = next(it)
  87. loss = self.loss(input, target)
  88. loss.backward()
  89. self.opt.step()
  90. loss_sum += loss.detach().cpu().item()
  91. return loss_sum
  92. def train(self, max_epochs):
  93. best_loss = None
  94. best_epoch = None
  95. for i in range(max_epochs):
  96. loss = self.run_epoch()
  97. if best_loss is None or loss < best_loss:
  98. best_loss = loss
  99. best_epoch = i
  100. return loss, best_loss, best_epoch