diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index 78ae8b7..68dfdd1 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -22,8 +22,8 @@ class DecodeLayer(torch.nn.Module): input_dim: List[int], data: Union[Data, PreparedData], keep_prob: float = 1., - activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, + activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, **kwargs) -> None: super().__init__(**kwargs) @@ -35,9 +35,9 @@ class DecodeLayer(torch.nn.Module): self.output_dim = 1 self.data = data self.keep_prob = keep_prob + self.decoder_class = decoder_class self.activation = activation - self.decoder_class = decoder_class self.decoders = None self.build() @@ -45,6 +45,7 @@ class DecodeLayer(torch.nn.Module): self.decoders = {} n = len(self.data.node_types) + relation_types = self.data.relation_types for node_type_row in range(n): if node_type_row not in relation_types: continue @@ -70,34 +71,16 @@ class DecodeLayer(torch.nn.Module): decoder_class = self.decoder_class self.decoders[node_type_row, node_type_column] = \ - decoder_class(self.input_dim, + decoder_class(self.input_dim[node_type_row], num_relation_types = len(rels), - drop_prob = 1. - self.keep_prob, + keep_prob = self.keep_prob, activation = self.activation) - def forward(self, last_layer_repr: List[torch.Tensor]) -> TrainValTest: - - # n = len(self.data.node_types) - # relation_types = self.data.relation_types - # for node_type_row in range(n): - # if node_type_row not in relation_types: - # continue - # - # for node_type_column in range(n): - # if node_type_column not in relation_types[node_type_row]: - # continue - # - # rels = relation_types[node_type_row][node_type_column] - # - # for mode in ['train', 'val', 'test']: - # getattr(relation_types[node_type_row][node_type_column].edges_pos, mode) - # getattr(self.data.edges_neg, mode) - # last_layer[] - + def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: res = {} for (node_type_row, node_type_column), dec in self.decoders.items(): inputs_row = last_layer_repr[node_type_row] inputs_column = last_layer_repr[node_type_column] - pred_adj_matrices = dec(inputs_row, inputs_col) - res[node_type_row, node_type_col] = pred_adj_matrices + pred_adj_matrices = dec(inputs_row, inputs_column) + res[node_type_row, node_type_column] = pred_adj_matrices return res