IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

91 строка
3.3KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from .data import Data
  7. from .trainprep import PreparedData, \
  8. TrainValTest
  9. from typing import Type, \
  10. List, \
  11. Callable, \
  12. Union, \
  13. Dict, \
  14. Tuple
  15. from .decode import DEDICOMDecoder
  16. class DecodeLayer(torch.nn.Module):
  17. def __init__(self,
  18. input_dim: List[int],
  19. data: PreparedData,
  20. keep_prob: float = 1.,
  21. decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
  22. activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
  23. **kwargs) -> None:
  24. super().__init__(**kwargs)
  25. if not isinstance(input_dim, list):
  26. raise TypeError('input_dim must be a List')
  27. if not all([ a == input_dim[0] for a in input_dim ]):
  28. raise ValueError('All elements of input_dim must have the same value')
  29. if not isinstance(data, PreparedData):
  30. raise TypeError('data must be an instance of PreparedData')
  31. if not isinstance(decoder_class, type) and \
  32. not isinstance(decoder_class, dict):
  33. raise TypeError('decoder_class must be a Type or a Dict')
  34. if not isinstance(decoder_class, dict):
  35. decoder_class = { k: decoder_class \
  36. for k in data.relation_types.keys() }
  37. self.input_dim = input_dim
  38. self.output_dim = 1
  39. self.data = data
  40. self.keep_prob = keep_prob
  41. self.decoder_class = decoder_class
  42. self.activation = activation
  43. self.decoders = None
  44. self.build()
  45. def build(self) -> None:
  46. self.decoders = {}
  47. for (node_type_row, node_type_column), rels in self.data.relation_types.items():
  48. if len(rels) == 0:
  49. continue
  50. if isinstance(self.decoder_class, dict):
  51. if (node_type_row, node_type_column) in self.decoder_class:
  52. decoder_class = self.decoder_class[node_type_row, node_type_column]
  53. elif (node_type_column, node_type_row) in self.decoder_class:
  54. decoder_class = self.decoder_class[node_type_column, node_type_row]
  55. else:
  56. raise KeyError('Decoder not specified for edge type: %s -- %s' % (
  57. self.data.node_types[node_type_row].name,
  58. self.data.node_types[node_type_column].name))
  59. else:
  60. decoder_class = self.decoder_class
  61. self.decoders[node_type_row, node_type_column] = \
  62. decoder_class(self.input_dim[node_type_row],
  63. num_relation_types = len(rels),
  64. keep_prob = self.keep_prob,
  65. activation = self.activation)
  66. def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]:
  67. res = {}
  68. for (node_type_row, node_type_column), dec in self.decoders.items():
  69. inputs_row = last_layer_repr[node_type_row]
  70. inputs_column = last_layer_repr[node_type_column]
  71. pred_adj_matrices = [ dec(inputs_row, inputs_column, k) for k in range(dec.num_relation_types) ]
  72. res[node_type_row, node_type_column] = pred_adj_matrices
  73. return res