IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

54 lines
1.6KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from .weights import init_glorot
  7. from .dropout import dropout
  8. from typing import Tuple, \
  9. List
  10. def dedicom_decoder(input_dim: int, num_relation_types: int) -> \
  11. Tuple[torch.Tensor, List[torch.Tensor]]:
  12. global_interaction = init_glorot(input_dim, input_dim)
  13. local_variation = [
  14. torch.diag(torch.flatten(init_glorot(input_dim, 1))) \
  15. for _ in range(num_relation_types)
  16. ]
  17. return (global_interaction, local_variation)
  18. def dist_mult_decoder(input_dim: int, num_relation_types: int) -> \
  19. Tuple[torch.Tensor, List[torch.Tensor]]:
  20. global_interaction = torch.eye(input_dim, input_dim)
  21. local_variation = [
  22. torch.diag(torch.flatten(init_glorot(input_dim, 1))) \
  23. for _ in range(num_relation_types)
  24. ]
  25. return (global_interaction, local_variation)
  26. def bilinear_decoder(input_dim: int, num_relation_types: int) -> \
  27. Tuple[torch.Tensor, List[torch.Tensor]]:
  28. global_interaction = torch.eye(input_dim, input_dim)
  29. local_variation = [
  30. init_glorot(input_dim, input_dim) \
  31. for _ in range(num_relation_types)
  32. ]
  33. return (global_interaction, local_variation)
  34. def inner_product_decoder(input_dim: int, num_relation_types: int) -> \
  35. Tuple[torch.Tensor, List[torch.Tensor]]:
  36. global_interaction = torch.eye(input_dim, input_dim)
  37. local_variation = torch.eye(input_dim, input_dim)
  38. local_variation = [ local_variation ] * num_relation_types
  39. return (global_interaction, local_variation)