|
|
@@ -1,2 +1,103 @@ |
|
|
|
# from .layer import DecagonLayer
|
|
|
|
# from .input import OneHotInputLayer
|
|
|
|
#
|
|
|
|
# Copyright (C) Stanislaw Adaszewski, 2020
|
|
|
|
# License: GPLv3
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from .data import Data
|
|
|
|
from .trainprep import PreparedData, \
|
|
|
|
TrainValTest
|
|
|
|
from typing import Type, \
|
|
|
|
List, \
|
|
|
|
Callable, \
|
|
|
|
Union, \
|
|
|
|
Dict, \
|
|
|
|
Tuple
|
|
|
|
from .decode import DEDICOMDecoder
|
|
|
|
|
|
|
|
|
|
|
|
class DecodeLayer(torch.nn.Module):
|
|
|
|
def __init__(self,
|
|
|
|
input_dim: List[int],
|
|
|
|
data: Union[Data, PreparedData],
|
|
|
|
keep_prob: float = 1.,
|
|
|
|
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
|
|
|
|
decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
|
|
|
|
**kwargs) -> None:
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
assert all([ a == input_dim[0] \
|
|
|
|
for a in input_dim ])
|
|
|
|
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.output_dim = 1
|
|
|
|
self.data = data
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
self.decoder_class = decoder_class
|
|
|
|
self.decoders = None
|
|
|
|
self.build()
|
|
|
|
|
|
|
|
def build(self) -> None:
|
|
|
|
self.decoders = {}
|
|
|
|
|
|
|
|
n = len(self.data.node_types)
|
|
|
|
for node_type_row in range(n):
|
|
|
|
if node_type_row not in relation_types:
|
|
|
|
continue
|
|
|
|
|
|
|
|
for node_type_column in range(n):
|
|
|
|
if node_type_column not in relation_types[node_type_row]:
|
|
|
|
continue
|
|
|
|
|
|
|
|
rels = relation_types[node_type_row][node_type_column]
|
|
|
|
if len(rels) == 0:
|
|
|
|
continue
|
|
|
|
|
|
|
|
if isinstance(self.decoder_class, dict):
|
|
|
|
if (node_type_row, node_type_column) in self.decoder_class:
|
|
|
|
decoder_class = self.decoder_class[node_type_row, node_type_column]
|
|
|
|
elif (node_type_column, node_type_row) in self.decoder_class:
|
|
|
|
decoder_class = self.decoder_class[node_type_column, node_type_row]
|
|
|
|
else:
|
|
|
|
raise KeyError('Decoder not specified for edge type: %s -- %s' % (
|
|
|
|
self.data.node_types[node_type_row].name,
|
|
|
|
self.data.node_types[node_type_column].name))
|
|
|
|
else:
|
|
|
|
decoder_class = self.decoder_class
|
|
|
|
|
|
|
|
self.decoders[node_type_row, node_type_column] = \
|
|
|
|
decoder_class(self.input_dim,
|
|
|
|
num_relation_types = len(rels),
|
|
|
|
drop_prob = 1. - self.keep_prob,
|
|
|
|
activation = self.activation)
|
|
|
|
|
|
|
|
def forward(self, last_layer_repr: List[torch.Tensor]) -> TrainValTest:
|
|
|
|
|
|
|
|
# n = len(self.data.node_types)
|
|
|
|
# relation_types = self.data.relation_types
|
|
|
|
# for node_type_row in range(n):
|
|
|
|
# if node_type_row not in relation_types:
|
|
|
|
# continue
|
|
|
|
#
|
|
|
|
# for node_type_column in range(n):
|
|
|
|
# if node_type_column not in relation_types[node_type_row]:
|
|
|
|
# continue
|
|
|
|
#
|
|
|
|
# rels = relation_types[node_type_row][node_type_column]
|
|
|
|
#
|
|
|
|
# for mode in ['train', 'val', 'test']:
|
|
|
|
# getattr(relation_types[node_type_row][node_type_column].edges_pos, mode)
|
|
|
|
# getattr(self.data.edges_neg, mode)
|
|
|
|
# last_layer[]
|
|
|
|
|
|
|
|
res = {}
|
|
|
|
for (node_type_row, node_type_column), dec in self.decoders.items():
|
|
|
|
inputs_row = last_layer_repr[node_type_row]
|
|
|
|
inputs_column = last_layer_repr[node_type_column]
|
|
|
|
pred_adj_matrices = dec(inputs_row, inputs_col)
|
|
|
|
res[node_type_row, node_type_col] = pred_adj_matrices
|
|
|
|
return res
|