|
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
- import torch
- from .weights import init_glorot
- from .dropout import dropout
- from typing import Tuple, \
- List
-
-
def dedicom_decoder(input_dim: int, num_relation_types: int) -> \
        Tuple[torch.Tensor, List[torch.Tensor]]:
    """Create the initial matrices for a DEDICOM-style decoder.

    Args:
        input_dim: Size passed to ``init_glorot`` for both the shared
            matrix and each per-relation vector.
        num_relation_types: Number of per-relation matrices to create.

    Returns:
        A ``(global_interaction, local_variation)`` tuple: the shared
        matrix obtained from ``init_glorot(input_dim, input_dim)`` and a
        list holding one diagonal matrix per relation type.
    """
    # The shared matrix is requested first so that the order of
    # init_glorot calls stays the same (matters if it draws random numbers).
    shared = init_glorot(input_dim, input_dim)

    per_relation = []
    for _ in range(num_relation_types):
        # init_glorot(input_dim, 1) is presumably an (input_dim, 1) tensor
        # -- TODO confirm in .weights. It is flattened to 1-D and then
        # placed on the diagonal of a square matrix.
        diagonal_entries = torch.flatten(init_glorot(input_dim, 1))
        per_relation.append(torch.diag(diagonal_entries))

    return shared, per_relation
-
-
def dist_mult_decoder(input_dim: int, num_relation_types: int) -> \
        Tuple[torch.Tensor, List[torch.Tensor]]:
    """Create the initial matrices for a DistMult-style decoder.

    Args:
        input_dim: Side length of the identity matrix and the size passed
            to ``init_glorot`` for each per-relation vector.
        num_relation_types: Number of per-relation matrices to create.

    Returns:
        A ``(global_interaction, local_variation)`` tuple: an
        ``input_dim`` x ``input_dim`` identity matrix and a list holding
        one diagonal matrix per relation type.
    """
    # The shared term is a fixed identity matrix; only the per-relation
    # diagonals come from init_glorot.
    identity = torch.eye(input_dim, input_dim)

    per_relation = []
    for _ in range(num_relation_types):
        # init_glorot(input_dim, 1) is presumably an (input_dim, 1) tensor
        # -- TODO confirm in .weights. Flatten it and use it as a diagonal.
        diagonal_entries = torch.flatten(init_glorot(input_dim, 1))
        per_relation.append(torch.diag(diagonal_entries))

    return identity, per_relation
-
-
def bilinear_decoder(input_dim: int, num_relation_types: int) -> \
        Tuple[torch.Tensor, List[torch.Tensor]]:
    """Create the initial matrices for a bilinear decoder.

    Args:
        input_dim: Side length of the identity matrix and both sizes
            passed to ``init_glorot`` for each per-relation matrix.
        num_relation_types: Number of per-relation matrices to create.

    Returns:
        A ``(global_interaction, local_variation)`` tuple: an
        ``input_dim`` x ``input_dim`` identity matrix and a list holding
        one ``init_glorot(input_dim, input_dim)`` result per relation type.
    """
    identity = torch.eye(input_dim, input_dim)

    # One separate init_glorot call per relation type, in order.
    per_relation = []
    for _ in range(num_relation_types):
        per_relation.append(init_glorot(input_dim, input_dim))

    return identity, per_relation
-
-
def inner_product_decoder(input_dim: int, num_relation_types: int) -> \
        Tuple[torch.Tensor, List[torch.Tensor]]:
    """Create the initial matrices for an inner-product decoder.

    Args:
        input_dim: Side length of the identity matrices.
        num_relation_types: Length of the returned per-relation list.

    Returns:
        A ``(global_interaction, local_variation)`` tuple: an
        ``input_dim`` x ``input_dim`` identity matrix and a list that
        contains a second identity matrix ``num_relation_types`` times.
    """
    shared = torch.eye(input_dim, input_dim)

    # A single identity tensor, distinct from `shared`, is reused for every
    # relation type: all list entries refer to the same object.
    relation_identity = torch.eye(input_dim, input_dim)
    per_relation = [relation_identity for _ in range(num_relation_types)]

    return shared, per_relation
|