|
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
- import torch
- from .data import Data
- from .trainprep import PreparedData, \
- TrainValTest
- from typing import Type, \
- List, \
- Callable, \
- Union, \
- Dict, \
- Tuple
- from .decode import DEDICOMDecoder
-
-
class DecodeLayer(torch.nn.Module):
    """Layer that decodes node representations into per-edge-type predictions.

    For every ordered pair ``(row node type, column node type)`` that has at
    least one relation type in ``data.relation_types``, one decoder instance
    is created.  ``forward`` then feeds the matching pair of node
    representations to each decoder and collects the results.

    NOTE(review): the decoders are kept in a plain ``dict``; if the decoder
    classes are ``torch.nn.Module`` subclasses holding parameters, those are
    not registered as submodules of this layer -- confirm whether that is
    intended.
    """

    def __init__(self,
                 input_dim: List[int],
                 data: Union[Data, PreparedData],
                 keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
                 decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
                 **kwargs) -> None:
        """Store the configuration and build the decoders.

        Args:
            input_dim: One entry per node type; all entries must be equal.
                Passed unchanged to every decoder as its first argument.
            data: Object exposing ``node_types`` and ``relation_types``.
            keep_prob: Decoders receive ``drop_prob = 1. - keep_prob``.
            activation: Activation passed to every decoder.
            decoder_class: Either a single decoder class used for every
                node-type pair, or a dict mapping ``(row, column)`` node-type
                index pairs to decoder classes.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.

        Raises:
            AssertionError: If the entries of ``input_dim`` are not all equal.
            KeyError: If ``decoder_class`` is a dict that lacks an entry for
                a node-type pair that has relation types.
        """
        super().__init__(**kwargs)

        assert all(a == input_dim[0] for a in input_dim), \
            'All entries of input_dim must be equal'

        self.input_dim = input_dim
        self.output_dim = 1
        self.data = data
        self.keep_prob = keep_prob
        self.activation = activation

        self.decoder_class = decoder_class
        self.decoders = None
        self.build()

    def _select_decoder_class(self, node_type_row: int,
                              node_type_column: int) -> Type:
        """Return the decoder class to use for the given node-type pair."""
        if not isinstance(self.decoder_class, dict):
            return self.decoder_class

        # A dict entry is accepted in either orientation of the pair.
        if (node_type_row, node_type_column) in self.decoder_class:
            return self.decoder_class[node_type_row, node_type_column]
        if (node_type_column, node_type_row) in self.decoder_class:
            return self.decoder_class[node_type_column, node_type_row]

        raise KeyError('Decoder not specified for edge type: %s -- %s' % (
            self.data.node_types[node_type_row].name,
            self.data.node_types[node_type_column].name))

    def build(self) -> None:
        """Create one decoder per node-type pair that has relation types.

        Populates ``self.decoders``, a dict keyed by
        ``(node_type_row, node_type_column)``.
        """
        self.decoders = {}

        # Previously this name was used without being bound (NameError).
        relation_types = self.data.relation_types

        n = len(self.data.node_types)
        for node_type_row in range(n):
            # Membership tests (rather than direct indexing) avoid touching
            # rows/columns that have no relations defined.
            if node_type_row not in relation_types:
                continue

            for node_type_column in range(n):
                if node_type_column not in relation_types[node_type_row]:
                    continue

                rels = relation_types[node_type_row][node_type_column]
                if len(rels) == 0:
                    continue

                decoder_class = self._select_decoder_class(
                    node_type_row, node_type_column)

                self.decoders[node_type_row, node_type_column] = \
                    decoder_class(self.input_dim,
                                  num_relation_types=len(rels),
                                  drop_prob=1. - self.keep_prob,
                                  activation=self.activation)

    def forward(self, last_layer_repr: List[torch.Tensor]) -> TrainValTest:
        """Run every decoder on its pair of node-type representations.

        Args:
            last_layer_repr: Indexed by node type; element ``i`` is the
                representation passed to decoders involving node type ``i``.

        Returns:
            A dict keyed by ``(node_type_row, node_type_column)`` holding
            whatever the corresponding decoder returned.

        NOTE(review): the return annotation says ``TrainValTest`` but a plain
        dict is returned -- confirm which one is intended.
        """
        # TODO: an earlier sketch iterated over the train/val/test edge
        # splits of each relation type here; that was never implemented.
        res = {}
        for (node_type_row, node_type_column), dec in self.decoders.items():
            inputs_row = last_layer_repr[node_type_row]
            inputs_column = last_layer_repr[node_type_column]
            pred_adj_matrices = dec(inputs_row, inputs_column)
            res[node_type_row, node_type_column] = pred_adj_matrices
        return res
|