#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from .data import Data
from .trainprep import PreparedData, \
    TrainValTest
from typing import Type, \
    List, \
    Callable, \
    Union, \
    Dict, \
    Tuple
from .decode import DEDICOMDecoder
from dataclasses import dataclass
import time
from .databatch import BatchedDataPointer


@dataclass
class RelationPredictions(object):
    """Decoded edge scores for a single relation type, split into
    train/val/test for the positive, negative and back-relation edges."""
    edges_pos: TrainValTest
    edges_neg: TrainValTest
    edges_back_pos: TrainValTest
    edges_back_neg: TrainValTest


@dataclass
class RelationFamilyPredictions(object):
    """Predictions for all relation types belonging to one relation family."""
    relation_types: List[RelationPredictions]


@dataclass
class Predictions(object):
    """Predictions for all relation families in the prepared data."""
    relation_families: List[RelationFamilyPredictions]


class DecodeLayer(torch.nn.Module):
    """Decodes the per-node-type representations produced by the
    convolutional layers into edge scores for every relation family
    and relation type in the prepared data."""

    def __init__(self,
        input_dim: List[int],
        data: PreparedData,
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
        batched_data_pointer: BatchedDataPointer = None,
        **kwargs) -> None:

        super().__init__(**kwargs)

        if not isinstance(input_dim, list):
            raise TypeError('input_dim must be a List')

        if len(input_dim) != len(data.node_types):
            raise ValueError('input_dim must have length equal to num_node_types')

        if not all([ a == input_dim[0] for a in input_dim ]):
            raise ValueError('All elements of input_dim must have the same value')

        if not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of PreparedData')

        if batched_data_pointer is not None and \
            not isinstance(batched_data_pointer, BatchedDataPointer):
            raise TypeError('batched_data_pointer must be an instance of BatchedDataPointer')

        # if batched_data_pointer is not None and not batched_data_pointer.compatible_with(data):
        #     raise ValueError('batched_data_pointer must be compatible with data')

        self.input_dim = input_dim[0]
        self.output_dim = 1
        self.data = data
        self.keep_prob = keep_prob
        self.activation = activation
        self.batched_data_pointer = batched_data_pointer

        self.decoders = None
        self.build()

    def build(self) -> None:
        # One decoder module per relation family, instantiated from the
        # family's decoder class (e.g. DEDICOMDecoder) with one relation
        # per relation type.
        self.decoders = torch.nn.ModuleList()
        for fam in self.data.relation_families:
            dec = fam.decoder_class(self.input_dim, len(fam.relation_types),
                self.keep_prob, self.activation)
            self.decoders.append(dec)

    def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec):
        # For each requested edge list (e.g. edges_pos, edges_neg) gather the
        # row and column node representations and decode them with the k-th
        # relation of the family decoder, separately for train/val/test.
        start_time = time.time()
        pred = []
        for p in edge_list_attr_names:
            tvt = []
            for t in ['train', 'val', 'test']:
                # print('r:', r)
                edges = getattr(getattr(r, p), t)
                inputs_row = last_layer_repr[row][edges[:, 0]]
                inputs_column = last_layer_repr[column][edges[:, 1]]
                tvt.append(dec(inputs_row, inputs_column, k))
            tvt = TrainValTest(*tvt)
            pred.append(tvt)
        # print('DecodeLayer._get_tvt() took:', time.time() - start_time)
        return pred

    def forward(self, last_layer_repr: List[torch.Tensor]) -> Predictions:
        t = time.time()
        res = []
        # When mini-batching, decode the edges of the current batch rather
        # than those of the full prepared data.
        data = self.batched_data_pointer.batched_data \
            if self.batched_data_pointer is not None \
            else self.data
        for i, fam in enumerate(data.relation_families):
            fam_pred = []
            for k, r in enumerate(fam.relation_types):
                pred = []
                pred += self._get_tvt(r, ['edges_pos', 'edges_neg'],
                    r.node_type_row, r.node_type_column, k,
                    last_layer_repr, self.decoders[i])
                pred += self._get_tvt(r, ['edges_back_pos', 'edges_back_neg'],
                    r.node_type_column, r.node_type_row, k,
                    last_layer_repr, self.decoders[i])
                pred = RelationPredictions(*pred)
                fam_pred.append(pred)
            fam_pred = RelationFamilyPredictions(fam_pred)
            res.append(fam_pred)
        res = Predictions(res)
        # print('DecodeLayer.forward() took', time.time() - t)
        return res
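

# Illustrative usage sketch (not part of the module): the names `prep` and
# `last_layer_repr` below are assumptions -- `prep` stands for a PreparedData
# object produced by the training-preparation step, and `last_layer_repr` for
# the list of per-node-type embedding tensors returned by the preceding
# convolution layers, here assumed to have width 32.
#
#   dec = DecodeLayer(input_dim=[32] * len(prep.node_types), data=prep,
#       keep_prob=0.9, activation=torch.sigmoid)
#   pred = dec(last_layer_repr)
#   # pred.relation_families[0].relation_types[0].edges_pos.train then holds
#   # the decoded scores for the positive training edges of the first
#   # relation type in the first relation family.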