IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

104 lines
3.7KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from .data import Data
  7. from .trainprep import PreparedData, \
  8. TrainValTest
  9. from typing import Type, \
  10. List, \
  11. Callable, \
  12. Union, \
  13. Dict, \
  14. Tuple
  15. from .decode import DEDICOMDecoder
  16. class DecodeLayer(torch.nn.Module):
  17. def __init__(self,
  18. input_dim: List[int],
  19. data: Union[Data, PreparedData],
  20. keep_prob: float = 1.,
  21. activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
  22. decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
  23. **kwargs) -> None:
  24. super().__init__(**kwargs)
  25. assert all([ a == input_dim[0] \
  26. for a in input_dim ])
  27. self.input_dim = input_dim
  28. self.output_dim = 1
  29. self.data = data
  30. self.keep_prob = keep_prob
  31. self.activation = activation
  32. self.decoder_class = decoder_class
  33. self.decoders = None
  34. self.build()
  35. def build(self) -> None:
  36. self.decoders = {}
  37. n = len(self.data.node_types)
  38. for node_type_row in range(n):
  39. if node_type_row not in relation_types:
  40. continue
  41. for node_type_column in range(n):
  42. if node_type_column not in relation_types[node_type_row]:
  43. continue
  44. rels = relation_types[node_type_row][node_type_column]
  45. if len(rels) == 0:
  46. continue
  47. if isinstance(self.decoder_class, dict):
  48. if (node_type_row, node_type_column) in self.decoder_class:
  49. decoder_class = self.decoder_class[node_type_row, node_type_column]
  50. elif (node_type_column, node_type_row) in self.decoder_class:
  51. decoder_class = self.decoder_class[node_type_column, node_type_row]
  52. else:
  53. raise KeyError('Decoder not specified for edge type: %s -- %s' % (
  54. self.data.node_types[node_type_row].name,
  55. self.data.node_types[node_type_column].name))
  56. else:
  57. decoder_class = self.decoder_class
  58. self.decoders[node_type_row, node_type_column] = \
  59. decoder_class(self.input_dim,
  60. num_relation_types = len(rels),
  61. drop_prob = 1. - self.keep_prob,
  62. activation = self.activation)
  63. def forward(self, last_layer_repr: List[torch.Tensor]) -> TrainValTest:
  64. # n = len(self.data.node_types)
  65. # relation_types = self.data.relation_types
  66. # for node_type_row in range(n):
  67. # if node_type_row not in relation_types:
  68. # continue
  69. #
  70. # for node_type_column in range(n):
  71. # if node_type_column not in relation_types[node_type_row]:
  72. # continue
  73. #
  74. # rels = relation_types[node_type_row][node_type_column]
  75. #
  76. # for mode in ['train', 'val', 'test']:
  77. # getattr(relation_types[node_type_row][node_type_column].edges_pos, mode)
  78. # getattr(self.data.edges_neg, mode)
  79. # last_layer[]
  80. res = {}
  81. for (node_type_row, node_type_column), dec in self.decoders.items():
  82. inputs_row = last_layer_repr[node_type_row]
  83. inputs_column = last_layer_repr[node_type_column]
  84. pred_adj_matrices = dec(inputs_row, inputs_col)
  85. res[node_type_row, node_type_col] = pred_adj_matrices
  86. return res