
model.py

from .data import Data
from typing import List, Callable
from .trainprep import PreparedData
import torch
from .convlayer import DecagonLayer
from .input import OneHotInputLayer
from types import FunctionType
from .declayer import DecodeLayer
from .batch import PredictionsBatch


class Model(torch.nn.Module):
    def __init__(self, prep_d: PreparedData,
        layer_dimensions: List[int] = [32, 64],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        keep_prob = float(keep_prob)

        # Note: FunctionType accepts only plain Python functions and lambdas;
        # built-ins such as torch.sigmoid or other callable objects are rejected.
        if not isinstance(rel_activation, FunctionType):
            raise TypeError('rel_activation must be a function')

        if not isinstance(layer_activation, FunctionType):
            raise TypeError('layer_activation must be a function')

        if not isinstance(dec_activation, FunctionType):
            raise TypeError('dec_activation must be a function')

        self.prep_d = prep_d
        self.layer_dimensions = layer_dimensions
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation

        self.seq = None
        self.build()

    def build(self):
        # One-hot input features, followed by one DecagonLayer per entry in
        # layer_dimensions, followed by a DecodeLayer that produces the
        # decoded predictions.
        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [ in_layer ]

        for dim in self.layer_dimensions:
            conv_layer = DecagonLayer(input_dim = last_output_dim,
                output_dim = [ dim ] * len(self.prep_d.node_types),
                data = self.prep_d,
                keep_prob = self.keep_prob,
                rel_activation = self.rel_activation,
                layer_activation = self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)

        dec_layer = DecodeLayer(input_dim = last_output_dim,
            data = self.prep_d,
            keep_prob = self.keep_prob,
            activation = self.dec_activation)
        seq.append(dec_layer)

        self.seq = torch.nn.Sequential(*seq)

    def forward(self, _):
        # The input argument is ignored; the graph and node features come
        # from prep_d, captured when the layers were built.
        return self.seq(None)
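
A minimal usage sketch, assuming a PreparedData instance (prep_d) has already been built elsewhere in the package (e.g. by the .trainprep module); the hyperparameter values, the optimizer setup, and the import path are illustrative only, not part of this file:

    # Hypothetical usage; prep_d is a PreparedData produced elsewhere in the package.
    import torch
    from .model import Model   # adjust to the package's actual top-level name

    model = Model(prep_d,
        layer_dimensions = [32, 64],                   # two graph-convolution layers
        keep_prob = 0.9,                               # dropout keep probability
        dec_activation = lambda x: torch.sigmoid(x))   # must be a plain function or lambda

    opt = torch.optim.Adam(model.parameters(), lr = 0.001)
    pred = model(None)   # forward() ignores its argument and returns the DecodeLayer output

Note that activations are wrapped in lambdas where needed, because the constructor's FunctionType checks reject built-in callables such as torch.sigmoid.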