| @@ -22,8 +22,8 @@ class DecodeLayer(torch.nn.Module): | |||
| input_dim: List[int], | |||
| data: Union[Data, PreparedData], | |||
| keep_prob: float = 1., | |||
| activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | |||
| decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, | |||
| activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | |||
| **kwargs) -> None: | |||
| super().__init__(**kwargs) | |||
| @@ -35,9 +35,9 @@ class DecodeLayer(torch.nn.Module): | |||
| self.output_dim = 1 | |||
| self.data = data | |||
| self.keep_prob = keep_prob | |||
| self.decoder_class = decoder_class | |||
| self.activation = activation | |||
| self.decoder_class = decoder_class | |||
| self.decoders = None | |||
| self.build() | |||
| @@ -45,6 +45,7 @@ class DecodeLayer(torch.nn.Module): | |||
| self.decoders = {} | |||
| n = len(self.data.node_types) | |||
| relation_types = self.data.relation_types | |||
| for node_type_row in range(n): | |||
| if node_type_row not in relation_types: | |||
| continue | |||
| @@ -70,34 +71,16 @@ class DecodeLayer(torch.nn.Module): | |||
| decoder_class = self.decoder_class | |||
| self.decoders[node_type_row, node_type_column] = \ | |||
| decoder_class(self.input_dim, | |||
| decoder_class(self.input_dim[node_type_row], | |||
| num_relation_types = len(rels), | |||
| drop_prob = 1. - self.keep_prob, | |||
| keep_prob = self.keep_prob, | |||
| activation = self.activation) | |||
| def forward(self, last_layer_repr: List[torch.Tensor]) -> TrainValTest: | |||
| # n = len(self.data.node_types) | |||
| # relation_types = self.data.relation_types | |||
| # for node_type_row in range(n): | |||
| # if node_type_row not in relation_types: | |||
| # continue | |||
| # | |||
| # for node_type_column in range(n): | |||
| # if node_type_column not in relation_types[node_type_row]: | |||
| # continue | |||
| # | |||
| # rels = relation_types[node_type_row][node_type_column] | |||
| # | |||
| # for mode in ['train', 'val', 'test']: | |||
| # getattr(relation_types[node_type_row][node_type_column].edges_pos, mode) | |||
| # getattr(self.data.edges_neg, mode) | |||
| # last_layer[] | |||
| def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: | |||
| res = {} | |||
| for (node_type_row, node_type_column), dec in self.decoders.items(): | |||
| inputs_row = last_layer_repr[node_type_row] | |||
| inputs_column = last_layer_repr[node_type_column] | |||
| pred_adj_matrices = dec(inputs_row, inputs_col) | |||
| res[node_type_row, node_type_col] = pred_adj_matrices | |||
| pred_adj_matrices = dec(inputs_row, inputs_column) | |||
| res[node_type_row, node_type_column] = pred_adj_matrices | |||
| return res | |||