#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#


import torch
from .data import Data
from .trainprep import PreparedData, \
    TrainValTest
from typing import Type, \
    List, \
    Callable, \
    Union, \
    Dict, \
    Tuple, \
    Any
from .decode import DEDICOMDecoder


class DecodeLayer(torch.nn.Module):
    """Apply one decoder per connected (row, column) node-type pair.

    A decoder is built for every pair of node-type indices that has at
    least one relation in ``data.relation_types``. ``forward`` feeds the
    corresponding node-type representations into each decoder.
    """

    def __init__(self,
        input_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
        decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
        **kwargs) -> None:
        """Store configuration and build the decoders.

        Args:
            input_dim: Input dimensions; all entries must be equal. The
                whole list is forwarded to each decoder constructor.
            data: Object exposing ``node_types`` (items with ``.name``)
                and ``relation_types`` (nested ``[row][column]`` lookup
                yielding a sized collection of relations).
            keep_prob: Keep probability; each decoder receives
                ``drop_prob = 1 - keep_prob``.
            activation: Passed to each decoder as ``activation``.
            decoder_class: Either one decoder class used for every pair, or
                a dict mapping ``(row, column)`` node-type index pairs to
                decoder classes. The reversed pair is also tried.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.

        Raises:
            AssertionError: If the entries of ``input_dim`` differ.
            KeyError: If ``decoder_class`` is a dict with no entry for a
                connected node-type pair (in either orientation).
        """
        super().__init__(**kwargs)
        assert all([ a == input_dim[0] for a in input_dim ]), \
            'All entries of input_dim must be equal'
        self.input_dim = input_dim
        self.output_dim = 1
        self.data = data
        self.keep_prob = keep_prob
        self.activation = activation
        self.decoder_class = decoder_class
        self.decoders = None
        self.build()

    def build(self) -> None:
        """Instantiate a decoder for each node-type pair with relations.

        Populates ``self.decoders``, a dict keyed by
        ``(node_type_row, node_type_column)``. Pairs missing from
        ``relation_types`` or having zero relations are skipped.
        """
        self.decoders = {}

        n = len(self.data.node_types)
        relation_types = self.data.relation_types
        for node_type_row in range(n):
            if node_type_row not in relation_types:
                continue

            for node_type_column in range(n):
                if node_type_column not in relation_types[node_type_row]:
                    continue

                rels = relation_types[node_type_row][node_type_column]
                if len(rels) == 0:
                    continue

                if isinstance(self.decoder_class, dict):
                    if (node_type_row, node_type_column) in self.decoder_class:
                        decoder_class = self.decoder_class[node_type_row, node_type_column]
                    elif (node_type_column, node_type_row) in self.decoder_class:
                        decoder_class = self.decoder_class[node_type_column, node_type_row]
                    else:
                        raise KeyError('Decoder not specified for edge type: %s -- %s' % (
                            self.data.node_types[node_type_row].name,
                            self.data.node_types[node_type_column].name))
                else:
                    decoder_class = self.decoder_class

                dec = decoder_class(self.input_dim,
                    num_relation_types = len(rels),
                    drop_prob = 1. - self.keep_prob,
                    activation = self.activation)
                self.decoders[node_type_row, node_type_column] = dec

                # A plain dict is invisible to nn.Module, so register the
                # decoder explicitly. Its parameters then appear in
                # .parameters()/state_dict() and follow .to()/train()/eval().
                # String names are used because Module names cannot be tuples.
                if isinstance(dec, torch.nn.Module):
                    self.add_module('decoder_%d_%d' % (node_type_row, node_type_column), dec)

    def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], Any]:
        """Run every decoder on its row/column node-type representations.

        Args:
            last_layer_repr: Representations indexed by node-type index.

        Returns:
            Dict mapping ``(node_type_row, node_type_column)`` to whatever
            the corresponding decoder returns for
            ``(last_layer_repr[row], last_layer_repr[column])``.
        """
        res = {}
        for (node_type_row, node_type_column), dec in self.decoders.items():
            inputs_row = last_layer_repr[node_type_row]
            inputs_column = last_layer_repr[node_type_column]
            pred_adj_matrices = dec(inputs_row, inputs_column)
            res[node_type_row, node_type_column] = pred_adj_matrices
        return res