IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
4.6KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from .weights import init_glorot
  7. from .dropout import dropout
  8. class DEDICOMDecoder(torch.nn.Module):
  9. """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
  10. def __init__(self, input_dim, num_relation_types, keep_prob=1.,
  11. activation=torch.sigmoid, **kwargs):
  12. super().__init__(**kwargs)
  13. self.input_dim = input_dim
  14. self.num_relation_types = num_relation_types
  15. self.keep_prob = keep_prob
  16. self.activation = activation
  17. self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
  18. self.local_variation = torch.nn.ParameterList([
  19. torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
  20. for _ in range(num_relation_types)
  21. ])
  22. def forward(self, inputs_row, inputs_col, relation_index):
  23. inputs_row = dropout(inputs_row, self.keep_prob)
  24. inputs_col = dropout(inputs_col, self.keep_prob)
  25. relation = torch.diag(self.local_variation[relation_index])
  26. product1 = torch.mm(inputs_row, relation)
  27. product2 = torch.mm(product1, self.global_interaction)
  28. product3 = torch.mm(product2, relation)
  29. rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
  30. inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
  31. rec = torch.flatten(rec)
  32. return self.activation(rec)
  33. class DistMultDecoder(torch.nn.Module):
  34. """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
  35. def __init__(self, input_dim, num_relation_types, keep_prob=1.,
  36. activation=torch.sigmoid, **kwargs):
  37. super().__init__(**kwargs)
  38. self.input_dim = input_dim
  39. self.num_relation_types = num_relation_types
  40. self.keep_prob = keep_prob
  41. self.activation = activation
  42. self.relation = torch.nn.ParameterList([
  43. torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
  44. for _ in range(num_relation_types)
  45. ])
  46. def forward(self, inputs_row, inputs_col, relation_index):
  47. inputs_row = dropout(inputs_row, self.keep_prob)
  48. inputs_col = dropout(inputs_col, self.keep_prob)
  49. relation = torch.diag(self.relation[relation_index])
  50. intermediate_product = torch.mm(inputs_row, relation)
  51. rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
  52. inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
  53. rec = torch.flatten(rec)
  54. return self.activation(rec)
  55. class BilinearDecoder(torch.nn.Module):
  56. """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
  57. def __init__(self, input_dim, num_relation_types, keep_prob=1.,
  58. activation=torch.sigmoid, **kwargs):
  59. super().__init__(**kwargs)
  60. self.input_dim = input_dim
  61. self.num_relation_types = num_relation_types
  62. self.keep_prob = keep_prob
  63. self.activation = activation
  64. self.relation = torch.nn.ParameterList([
  65. torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
  66. for _ in range(num_relation_types)
  67. ])
  68. def forward(self, inputs_row, inputs_col, relation_index):
  69. inputs_row = dropout(inputs_row, self.keep_prob)
  70. inputs_col = dropout(inputs_col, self.keep_prob)
  71. intermediate_product = torch.mm(inputs_row, self.relation[relation_index])
  72. rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
  73. inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
  74. rec = torch.flatten(rec)
  75. return self.activation(rec)
  76. class InnerProductDecoder(torch.nn.Module):
  77. """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
  78. def __init__(self, input_dim, num_relation_types, keep_prob=1.,
  79. activation=torch.sigmoid, **kwargs):
  80. super().__init__(**kwargs)
  81. self.input_dim = input_dim
  82. self.num_relation_types = num_relation_types
  83. self.keep_prob = keep_prob
  84. self.activation = activation
  85. def forward(self, inputs_row, inputs_col, _):
  86. inputs_row = dropout(inputs_row, self.keep_prob)
  87. inputs_col = dropout(inputs_col, self.keep_prob)
  88. rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
  89. inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
  90. rec = torch.flatten(rec)
  91. return self.activation(rec)