diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index 46bffca..78ae8b7 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -1,2 +1,103 @@ -# from .layer import DecagonLayer -# from .input import OneHotInputLayer +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import torch +from .data import Data +from .trainprep import PreparedData, \ + TrainValTest +from typing import Type, \ + List, \ + Callable, \ + Union, \ + Dict, \ + Tuple +from .decode import DEDICOMDecoder + + +class DecodeLayer(torch.nn.Module): + def __init__(self, + input_dim: List[int], + data: Union[Data, PreparedData], + keep_prob: float = 1., + activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, + decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, + **kwargs) -> None: + + super().__init__(**kwargs) + + assert all([ a == input_dim[0] \ + for a in input_dim ]) + + self.input_dim = input_dim + self.output_dim = 1 + self.data = data + self.keep_prob = keep_prob + self.activation = activation + + self.decoder_class = decoder_class + self.decoders = None + self.build() + + def build(self) -> None: + self.decoders = {} + + n = len(self.data.node_types) + for node_type_row in range(n): + if node_type_row not in relation_types: + continue + + for node_type_column in range(n): + if node_type_column not in relation_types[node_type_row]: + continue + + rels = relation_types[node_type_row][node_type_column] + if len(rels) == 0: + continue + + if isinstance(self.decoder_class, dict): + if (node_type_row, node_type_column) in self.decoder_class: + decoder_class = self.decoder_class[node_type_row, node_type_column] + elif (node_type_column, node_type_row) in self.decoder_class: + decoder_class = self.decoder_class[node_type_column, node_type_row] + else: + raise KeyError('Decoder not specified for edge type: %s -- %s' % ( + self.data.node_types[node_type_row].name, + self.data.node_types[node_type_column].name)) + else: + decoder_class = self.decoder_class + + self.decoders[node_type_row, node_type_column] = \ + decoder_class(self.input_dim, + num_relation_types = len(rels), + drop_prob = 1. - self.keep_prob, + activation = self.activation) + + def forward(self, last_layer_repr: List[torch.Tensor]) -> TrainValTest: + + # n = len(self.data.node_types) + # relation_types = self.data.relation_types + # for node_type_row in range(n): + # if node_type_row not in relation_types: + # continue + # + # for node_type_column in range(n): + # if node_type_column not in relation_types[node_type_row]: + # continue + # + # rels = relation_types[node_type_row][node_type_column] + # + # for mode in ['train', 'val', 'test']: + # getattr(relation_types[node_type_row][node_type_column].edges_pos, mode) + # getattr(self.data.edges_neg, mode) + # last_layer[] + + res = {} + for (node_type_row, node_type_column), dec in self.decoders.items(): + inputs_row = last_layer_repr[node_type_row] + inputs_column = last_layer_repr[node_type_column] + pred_adj_matrices = dec(inputs_row, inputs_col) + res[node_type_row, node_type_col] = pred_adj_matrices + return res diff --git a/tests/icosagon/test_sampling.py b/tests/icosagon/test_sampling.py index 3bc3327..8552949 100644 --- a/tests/icosagon/test_sampling.py +++ b/tests/icosagon/test_sampling.py @@ -132,7 +132,7 @@ def test_unigram_03(): counts_tf = defaultdict(list) counts_torch = defaultdict(list) - for i in range(100): + for i in range(10): neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler( true_classes=true_classes_tf, num_true=num_true, diff --git a/tests/icosagon/test_weights.py b/tests/icosagon/test_weights.py index 7456076..5ddb997 100644 --- a/tests/icosagon/test_weights.py +++ b/tests/icosagon/test_weights.py @@ -11,3 +11,13 @@ def test_init_glorot_01(): init_range = np.sqrt(6.0 / 30) expected = -init_range + 2 * init_range * rnd assert torch.all(res == expected) + + +def test_init_glorot_02(): + torch.random.manual_seed(0) + res = init_glorot(20, 10) + torch.random.manual_seed(0) + rnd = torch.rand((20, 10)) + init_range = np.sqrt(6.0 / 30) + expected = -init_range + 2 * init_range * rnd + assert torch.all(res == expected)