|
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
import logging
import time
from dataclasses import dataclass
from typing import Type, \
    List, \
    Callable, \
    Union, \
    Dict, \
    Tuple

import torch

from .data import Data
from .decode import DEDICOMDecoder
from .trainprep import PreparedData, \
    TrainValTest
-
-
@dataclass
class RelationPredictions:
    """Decoder outputs for one relation type.

    Every field is a ``TrainValTest`` holding one output per data split.
    Field order is significant: instances are built positionally from a
    list of four ``TrainValTest`` objects.
    """

    # Outputs for the relation's ``edges_pos`` / ``edges_neg`` edge lists.
    edges_pos: TrainValTest
    edges_neg: TrainValTest
    # Outputs for the ``edges_back_pos`` / ``edges_back_neg`` edge lists
    # (decoded with the row/column node types swapped).
    edges_back_pos: TrainValTest
    edges_back_neg: TrainValTest
-
-
@dataclass
class RelationFamilyPredictions:
    """Predictions for every relation type of a single relation family.

    ``relation_types[k]`` holds the outputs for the family's k-th relation
    type.
    """

    relation_types: List[RelationPredictions]
-
-
@dataclass
class Predictions:
    """Top-level container: one ``RelationFamilyPredictions`` per family."""

    relation_families: List[RelationFamilyPredictions]
-
-
class DecodeLayer(torch.nn.Module):
    """Decode node representations into edge predictions for every relation.

    One decoder module is built per relation family of ``data`` (via the
    family's ``decoder_class``). All relation types of a family share that
    decoder and are distinguished by their index ``k`` within the family.
    """

    # Timing diagnostics are emitted at DEBUG level instead of print() so a
    # training loop does not write to stdout on every forward pass.
    _logger = logging.getLogger(__name__)

    def __init__(self,
                 input_dim: List[int],
                 data: PreparedData,
                 keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
                 **kwargs) -> None:
        """Validate arguments and build one decoder per relation family.

        Args:
            input_dim: per-node-type representation sizes. Must have one entry
                per node type in ``data``, and all entries must be equal.
            data: prepared dataset providing node types and relation families.
            keep_prob: passed positionally to each decoder's constructor.
            activation: passed positionally to each decoder's constructor.
            **kwargs: forwarded to ``torch.nn.Module.__init__``.

        Raises:
            TypeError: if ``input_dim`` is not a list or ``data`` is not a
                ``PreparedData``.
            ValueError: if ``input_dim`` has the wrong length, is empty, or
                its elements are not all equal.
        """
        super().__init__(**kwargs)

        if not isinstance(input_dim, list):
            raise TypeError('input_dim must be a List')

        # Check the type of ``data`` before touching ``data.node_types`` so a
        # wrong argument produces the intended TypeError, not AttributeError.
        if not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of PreparedData')

        if len(input_dim) != len(data.node_types):
            raise ValueError('input_dim must have length equal to num_node_types')

        # Without this, an empty list would fail later on ``input_dim[0]``
        # with an uninformative IndexError.
        if not input_dim:
            raise ValueError('input_dim must not be empty')

        if not all(a == input_dim[0] for a in input_dim):
            raise ValueError('All elements of input_dim must have the same value')

        self.input_dim = input_dim[0]
        self.output_dim = 1
        self.data = data
        self.keep_prob = keep_prob
        self.activation = activation

        self.decoders = None
        self.build()

    def build(self) -> None:
        """Create ``self.decoders``: one decoder per relation family, in order."""
        self.decoders = torch.nn.ModuleList()
        for fam in self.data.relation_families:
            dec = fam.decoder_class(self.input_dim, len(fam.relation_types),
                                    self.keep_prob, self.activation)
            self.decoders.append(dec)

    def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec):
        """Run ``dec`` over the train/val/test splits of several edge lists of ``r``.

        For each name in ``edge_list_attr_names``, ``getattr(r, name)`` must
        expose ``train``, ``val`` and ``test`` edge-index tensors. Column 0 of
        each tensor indexes ``last_layer_repr[row]`` and column 1 indexes
        ``last_layer_repr[column]``. The tensors are presumably of shape
        (num_edges, 2) — confirm against PreparedData.

        Returns:
            A list with one ``TrainValTest`` of decoder outputs per name, in
            the order of ``edge_list_attr_names``.
        """
        start_time = time.perf_counter()
        row_repr = last_layer_repr[row]
        column_repr = last_layer_repr[column]
        pred = []
        for name in edge_list_attr_names:
            edge_lists = getattr(r, name)
            outputs = []
            for split in ('train', 'val', 'test'):
                edges = getattr(edge_lists, split)
                outputs.append(dec(row_repr[edges[:, 0]], column_repr[edges[:, 1]], k))
            pred.append(TrainValTest(*outputs))
        self._logger.debug('DecodeLayer._get_tvt() took: %s',
                           time.perf_counter() - start_time)
        return pred

    def forward(self, last_layer_repr: List[torch.Tensor]) -> 'Predictions':
        """Decode predictions for every relation type of every relation family.

        Args:
            last_layer_repr: node representations indexed by node type, i.e.
                ``last_layer_repr[node_type]``.

        Returns:
            A ``Predictions`` holding one ``RelationFamilyPredictions`` per
            family in ``self.data.relation_families``. Each of those holds one
            ``RelationPredictions`` per relation type. ``edges_pos`` and
            ``edges_neg`` are decoded as (row, column); the ``edges_back_*``
            lists are decoded with the two node types swapped.
        """
        start_time = time.perf_counter()
        res = []
        for i, fam in enumerate(self.data.relation_families):
            dec = self.decoders[i]
            fam_pred = []
            for k, r in enumerate(fam.relation_types):
                pred = self._get_tvt(r, ['edges_pos', 'edges_neg'],
                                     r.node_type_row, r.node_type_column, k,
                                     last_layer_repr, dec)
                # Reverse-direction edge lists: row and column roles swapped.
                pred += self._get_tvt(r, ['edges_back_pos', 'edges_back_neg'],
                                      r.node_type_column, r.node_type_row, k,
                                      last_layer_repr, dec)
                fam_pred.append(RelationPredictions(*pred))
            res.append(RelationFamilyPredictions(fam_pred))
        self._logger.debug('DecodeLayer.forward() took %s',
                           time.perf_counter() - start_time)
        return Predictions(res)
|