IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
25'ten fazla konu seçemezsiniz Konular bir harf veya rakamla başlamalı, kısa çizgiler ('-') içerebilir ve en fazla 35 karakter uzunluğunda olabilir.

61 satır
2.2KB

  1. from .layer import Layer
  2. import torch
  3. from ..data import Data
  4. from typing import Type, \
  5. List, \
  6. Callable, \
  7. Union, \
  8. Dict, \
  9. Tuple
  10. from ..decode import DEDICOMDecoder
  class DecodeLayer(torch.nn.Module):
      """Decoder stage mapping last-layer node representations to per-edge-type outputs.

      One decoder is built for every ``(node_type_row, node_type_col)`` key of
      ``data.relation_types``; ``forward`` applies each decoder to the
      representations of its row and column node types.

      Args:
          data: Object whose ``relation_types`` mapping, keyed by
              ``(node_type_row, node_type_col)``, drives both decoder
              construction and the forward pass. ``len()`` of each value is
              passed to the decoder as ``num_relation_types``.
          last_layer: Preceding layer. Every entry of its ``output_dim`` must
              be equal; that common value becomes ``self.input_dim``.
          keep_prob: Keep probability; decoders receive
              ``drop_prob = 1 - keep_prob``.
          activation: Callable forwarded unchanged to every decoder.
          decoder_class: Either one decoder class used for every edge type,
              or a dict mapping ``(node_type_row, node_type_col)`` to the class
              used for that edge type.
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self,
          data: Data,
          last_layer: Layer,
          keep_prob: float = 1.,
          activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
          decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
          **kwargs) -> None:

          super().__init__(**kwargs)
          self.data = data
          self.last_layer = last_layer
          self.keep_prob = keep_prob
          self.activation = activation
          # A single input_dim is handed to every decoder, so all node types
          # must come out of the last layer with the same width.
          assert all(a == last_layer.output_dim[0]
              for a in last_layer.output_dim), \
              'last_layer.output_dim must contain one repeated value'
          self.input_dim = last_layer.output_dim[0]
          # Fixed to 1 here; how it is consumed lies outside this file.
          self.output_dim = 1
          self.decoder_class = decoder_class
          self.decoders = None
          self.build()

      def build(self) -> None:
          """Create one decoder per edge type and register it with this module.

          Decoders are stored in ``self.decoders`` keyed by
          ``(node_type_row, node_type_col)``. A plain dict attribute does NOT
          register its values as submodules. So every decoder that is a
          ``torch.nn.Module`` is additionally attached via ``add_module``. Its
          parameters then appear in ``parameters()`` / ``state_dict()`` and it
          follows ``.to()`` / ``train()`` / ``eval()``.

          Raises:
              KeyError: If ``decoder_class`` is a dict lacking an entry for an
                  edge type present in ``data.relation_types``.
          """
          self.decoders = {}
          for (node_type_row, node_type_col), rels in self.data.relation_types.items():
              key = (node_type_row, node_type_col)
              if isinstance(self.decoder_class, dict):
                  if key not in self.decoder_class:
                      raise KeyError('Decoder not specified for edge type: %d -- %d' % key)
                  decoder_class = self.decoder_class[key]
              else:
                  decoder_class = self.decoder_class
              decoder = decoder_class(self.input_dim,
                  num_relation_types=len(rels),
                  drop_prob=1. - self.keep_prob,
                  activation=self.activation)
              self.decoders[key] = decoder
              # Register so the decoder's parameters are trained, saved and
              # moved together with this layer; non-Module decoders are left
              # as plain entries, exactly as before.
              if isinstance(decoder, torch.nn.Module):
                  self.add_module('decoder_%s_%s' % (node_type_row, node_type_col),
                      decoder)

      def forward(self, last_layer_repr: List[torch.Tensor]) -> dict:
          """Apply every edge-type decoder to its row/column representations.

          Args:
              last_layer_repr: Indexable by node type; entry ``t`` is the
                  representation of node type ``t`` (presumably a
                  ``(num_nodes, input_dim)`` tensor -- TODO confirm).

          Returns:
              Dict mapping ``(node_type_row, node_type_col)`` to whatever the
              corresponding decoder returns for that pair.
          """
          res = {}
          for node_type_row, node_type_col in self.data.relation_types:
              key = (node_type_row, node_type_col)
              inputs_row = last_layer_repr[node_type_row]
              inputs_col = last_layer_repr[node_type_col]
              res[key] = self.decoders[key](inputs_row, inputs_col)
          return res