@@ -22,8 +22,8 @@ class DecodeLayer(torch.nn.Module):
             input_dim: List[int],
             data: Union[Data, PreparedData],
             keep_prob: float = 1.,
-            activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
             decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
+            activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
             **kwargs) -> None:
         super().__init__(**kwargs)
@@ -35,9 +35,9 @@ class DecodeLayer(torch.nn.Module):
         self.output_dim = 1
         self.data = data
         self.keep_prob = keep_prob
-        self.decoder_class = decoder_class
         self.activation = activation
+        self.decoder_class = decoder_class
         self.decoders = None
         self.build()
@@ -45,6 +45,7 @@ class DecodeLayer(torch.nn.Module):
         self.decoders = {}
         n = len(self.data.node_types)
+        relation_types = self.data.relation_types
         for node_type_row in range(n):
             if node_type_row not in relation_types:
                 continue
@@ -70,34 +71,16 @@ class DecodeLayer(torch.nn.Module):
                     decoder_class = self.decoder_class
                 self.decoders[node_type_row, node_type_column] = \
-                    decoder_class(self.input_dim,
+                    decoder_class(self.input_dim[node_type_row],
                         num_relation_types = len(rels),
-                        drop_prob = 1. - self.keep_prob,
+                        keep_prob = self.keep_prob,
                         activation = self.activation)
 
-    def forward(self, last_layer_repr: List[torch.Tensor]) -> TrainValTest:
-        # n = len(self.data.node_types)
-        # relation_types = self.data.relation_types
-        #
-        # for node_type_row in range(n):
-        #     if node_type_row not in relation_types:
-        #         continue
-        #
-        #     for node_type_column in range(n):
-        #         if node_type_column not in relation_types[node_type_row]:
-        #             continue
-        #
-        #         rels = relation_types[node_type_row][node_type_column]
-        #
-        #         for mode in ['train', 'val', 'test']:
-        #             getattr(relation_types[node_type_row][node_type_column].edges_pos, mode)
-        #             getattr(self.data.edges_neg, mode)
-        #             last_layer[]
+    def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]:
         res = {}
         for (node_type_row, node_type_column), dec in self.decoders.items():
             inputs_row = last_layer_repr[node_type_row]
             inputs_column = last_layer_repr[node_type_column]
-            pred_adj_matrices = dec(inputs_row, inputs_col)
-            res[node_type_row, node_type_col] = pred_adj_matrices
+            pred_adj_matrices = dec(inputs_row, inputs_column)
+            res[node_type_row, node_type_column] = pred_adj_matrices
         return res
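
A minimal sketch of how the reworked forward() return value could be consumed, assuming each decoder yields one predicted adjacency matrix per relation type between its node-type pair; the random tensors and shapes below are placeholders standing in for actual DecodeLayer output:

import torch

# Hypothetical stand-in for DecodeLayer.forward() output:
# Dict[Tuple[int, int], List[torch.Tensor]] keyed by (node_type_row, node_type_column),
# with one predicted adjacency matrix per relation type between those node types.
pred = {
    (0, 1): [torch.rand(100, 50), torch.rand(100, 50)],  # two relation types
    (1, 1): [torch.rand(50, 50)],                         # one relation type
}

for (node_type_row, node_type_column), adj_per_relation in pred.items():
    for rel_idx, adj in enumerate(adj_per_relation):
        # each adj holds edge scores for one relation type between the two node types
        print(node_type_row, node_type_column, rel_idx, tuple(adj.shape))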