IF YOU WOULD LIKE TO GET AN ACCOUNT, please email s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

67 rindas
2.3KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from .layer import Layer
  6. import torch
  7. from ..data import Data
  8. from typing import Type, \
  9. List, \
  10. Callable, \
  11. Union, \
  12. Dict, \
  13. Tuple
  14. from ..decode.cartesian import DEDICOMDecoder
  15. class DecodeLayer(torch.nn.Module):
  16. def __init__(self,
  17. data: Data,
  18. last_layer: Layer,
  19. keep_prob: float = 1.,
  20. activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
  21. decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, **kwargs) -> None:
  22. super().__init__(**kwargs)
  23. self.data = data
  24. self.last_layer = last_layer
  25. self.keep_prob = keep_prob
  26. self.activation = activation
  27. assert all([a == last_layer.output_dim[0] \
  28. for a in last_layer.output_dim])
  29. self.input_dim = last_layer.output_dim[0]
  30. self.output_dim = 1
  31. self.decoder_class = decoder_class
  32. self.decoders = None
  33. self.build()
  34. def build(self) -> None:
  35. self.decoders = {}
  36. for (node_type_row, node_type_col), rels in self.data.relation_types.items():
  37. key = (node_type_row, node_type_col)
  38. if isinstance(self.decoder_class, dict):
  39. if key in self.decoder_class:
  40. decoder_class = self.decoder_class[key]
  41. else:
  42. raise KeyError('Decoder not specified for edge type: %d -- %d' % key)
  43. else:
  44. decoder_class = self.decoder_class
  45. self.decoders[key] = decoder_class(self.input_dim,
  46. num_relation_types = len(rels),
  47. drop_prob = 1. - self.keep_prob,
  48. activation = self.activation)
  49. def forward(self, last_layer_repr: List[torch.Tensor]):
  50. res = {}
  51. for (node_type_row, node_type_col), rel in self.data.relation_types.items():
  52. key = (node_type_row, node_type_col)
  53. inputs_row = last_layer_repr[node_type_row]
  54. inputs_col = last_layer_repr[node_type_col]
  55. pred_adj_matrices = self.decoders[key](inputs_row, inputs_col)
  56. res[node_type_row, node_type_col] = pred_adj_matrices
  57. return res