IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

132 lignes
5.0KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from ..weights import init_glorot
  7. from ..dropout import dropout
  class DEDICOMDecoder(torch.nn.Module):
      """DEDICOM tensor factorization decoder layer for link prediction.

      For each relation type ``k`` the score of a (row, col) pair is
      ``row @ diag(d_k) @ R @ diag(d_k) @ col`` passed through ``activation``.
      ``R`` (``global_interaction``) is shared by all relation types.
      ``d_k`` (``local_variation[k]``) is specific to relation type ``k``.

      Args:
          input_dim: Size of the feature dimension of the row/column inputs.
          num_relation_types: Number of relation types; ``forward`` returns
              one score tensor per type.
          drop_prob: Dropout probability; ``dropout`` is called with
              ``1 - drop_prob`` (presumably its keep probability).
          activation: Callable applied to the raw scores.
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                   activation=torch.sigmoid, **kwargs):
          super().__init__(**kwargs)
          self.input_dim = input_dim
          self.num_relation_types = num_relation_types
          self.drop_prob = drop_prob
          self.activation = activation
          # Parameter / ParameterList register the weights with the module so
          # they appear in parameters() / state_dict() and follow .to(device).
          # A plain Python list of tensors is invisible to the module, so an
          # optimizer built from model.parameters() would never update it.
          self.global_interaction = torch.nn.Parameter(
              init_glorot(input_dim, input_dim))
          self.local_variation = torch.nn.ParameterList([
              torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1)))
              for _ in range(num_relation_types)
          ])

      def forward(self, inputs_row, inputs_col):
          """Score paired rows of ``inputs_row`` and ``inputs_col``.

          Args:
              inputs_row: 2-D tensor; row ``i`` is paired with row ``i`` of
                  ``inputs_col``.
              inputs_col: 2-D tensor with the same number of rows as
                  ``inputs_row``.

          Returns:
              List of ``num_relation_types`` 1-D tensors holding one
              activated score per row pair.
          """
          outputs = []
          for k in range(self.num_relation_types):
              # Drop out from the ORIGINAL inputs on every iteration; rebinding
              # inputs_row / inputs_col here would compound dropout so that
              # relation k saw k + 1 successive dropout passes.
              # NOTE(review): dropout is called regardless of self.training;
              # confirm whether the helper honours eval mode.
              row = dropout(inputs_row, 1. - self.drop_prob)
              col = dropout(inputs_col, 1. - self.drop_prob)
              relation = torch.diag(self.local_variation[k])
              left = torch.mm(row, relation)
              left = torch.mm(left, self.global_interaction)
              left = torch.mm(left, relation)
              # Row-wise dot product: (N, 1, D) x (N, D, 1) -> (N, 1, 1).
              # unsqueeze (unlike view) also accepts non-contiguous inputs.
              rec = torch.bmm(left.unsqueeze(1), col.unsqueeze(2))
              outputs.append(self.activation(torch.flatten(rec)))
          return outputs
  class DistMultDecoder(torch.nn.Module):
      """DistMult tensor factorization decoder layer for link prediction.

      For each relation type ``k`` the score of a (row, col) pair is
      ``row @ diag(r_k) @ col`` passed through ``activation``.
      ``r_k`` (``relation[k]``) is a learned vector for relation type ``k``.

      Args:
          input_dim: Size of the feature dimension of the row/column inputs.
          num_relation_types: Number of relation types; ``forward`` returns
              one score tensor per type.
          drop_prob: Dropout probability; ``dropout`` is called with
              ``1 - drop_prob`` (presumably its keep probability).
          activation: Callable applied to the raw scores.
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                   activation=torch.sigmoid, **kwargs):
          super().__init__(**kwargs)
          self.input_dim = input_dim
          self.num_relation_types = num_relation_types
          self.drop_prob = drop_prob
          self.activation = activation
          # ParameterList registers the per-relation vectors with the module;
          # a plain Python list would hide them from parameters(),
          # state_dict() and .to(device), so they would never be trained.
          self.relation = torch.nn.ParameterList([
              torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1)))
              for _ in range(num_relation_types)
          ])

      def forward(self, inputs_row, inputs_col):
          """Score paired rows of ``inputs_row`` and ``inputs_col``.

          Args:
              inputs_row: 2-D tensor; row ``i`` is paired with row ``i`` of
                  ``inputs_col``.
              inputs_col: 2-D tensor with the same number of rows as
                  ``inputs_row``.

          Returns:
              List of ``num_relation_types`` 1-D tensors holding one
              activated score per row pair.
          """
          outputs = []
          for k in range(self.num_relation_types):
              # Drop out from the ORIGINAL inputs on every iteration; rebinding
              # inputs_row / inputs_col here would compound dropout across
              # relation types.
              # NOTE(review): dropout is called regardless of self.training;
              # confirm whether the helper honours eval mode.
              row = dropout(inputs_row, 1. - self.drop_prob)
              col = dropout(inputs_col, 1. - self.drop_prob)
              relation = torch.diag(self.relation[k])
              left = torch.mm(row, relation)
              # Row-wise dot product: (N, 1, D) x (N, D, 1) -> (N, 1, 1).
              rec = torch.bmm(left.unsqueeze(1), col.unsqueeze(2))
              outputs.append(self.activation(torch.flatten(rec)))
          return outputs
  class BilinearDecoder(torch.nn.Module):
      """Bilinear decoder layer for link prediction.

      For each relation type ``k`` the score of a (row, col) pair is
      ``row @ W_k @ col`` passed through ``activation``.
      ``W_k`` (``relation[k]``) is a learned ``input_dim x input_dim`` matrix
      for relation type ``k``.

      Args:
          input_dim: Size of the feature dimension of the row/column inputs.
          num_relation_types: Number of relation types; ``forward`` returns
              one score tensor per type.
          drop_prob: Dropout probability; ``dropout`` is called with
              ``1 - drop_prob`` (presumably its keep probability).
          activation: Callable applied to the raw scores.
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                   activation=torch.sigmoid, **kwargs):
          super().__init__(**kwargs)
          self.input_dim = input_dim
          self.num_relation_types = num_relation_types
          self.drop_prob = drop_prob
          self.activation = activation
          # ParameterList registers the per-relation matrices with the module;
          # a plain Python list would hide them from parameters(),
          # state_dict() and .to(device), so they would never be trained.
          self.relation = torch.nn.ParameterList([
              torch.nn.Parameter(init_glorot(input_dim, input_dim))
              for _ in range(num_relation_types)
          ])

      def forward(self, inputs_row, inputs_col):
          """Score paired rows of ``inputs_row`` and ``inputs_col``.

          Args:
              inputs_row: 2-D tensor; row ``i`` is paired with row ``i`` of
                  ``inputs_col``.
              inputs_col: 2-D tensor with the same number of rows as
                  ``inputs_row``.

          Returns:
              List of ``num_relation_types`` 1-D tensors holding one
              activated score per row pair.
          """
          outputs = []
          for k in range(self.num_relation_types):
              # Drop out from the ORIGINAL inputs on every iteration; rebinding
              # inputs_row / inputs_col here would compound dropout across
              # relation types.
              # NOTE(review): dropout is called regardless of self.training;
              # confirm whether the helper honours eval mode.
              row = dropout(inputs_row, 1. - self.drop_prob)
              col = dropout(inputs_col, 1. - self.drop_prob)
              left = torch.mm(row, self.relation[k])
              # Row-wise dot product: (N, 1, D) x (N, D, 1) -> (N, 1, 1).
              rec = torch.bmm(left.unsqueeze(1), col.unsqueeze(2))
              outputs.append(self.activation(torch.flatten(rec)))
          return outputs
  class InnerProductDecoder(torch.nn.Module):
      """Inner-product decoder layer for link prediction (no learned weights).

      For each relation type the score of a (row, col) pair is the plain dot
      product ``row . col`` passed through ``activation``. All relation types
      share the same computation; their outputs differ only through the
      independent dropout draws made per relation type.

      Args:
          input_dim: Size of the feature dimension of the inputs (stored only).
          num_relation_types: Number of relation types; ``forward`` returns
              one score tensor per type.
          drop_prob: Dropout probability; ``dropout`` is called with
              ``1 - drop_prob`` (presumably its keep probability).
          activation: Callable applied to the raw scores.
          **kwargs: Forwarded to ``torch.nn.Module.__init__``.
      """

      def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                   activation=torch.sigmoid, **kwargs):
          super().__init__(**kwargs)
          self.input_dim = input_dim
          self.num_relation_types = num_relation_types
          self.drop_prob = drop_prob
          self.activation = activation

      def forward(self, inputs_row, inputs_col):
          """Score paired rows of ``inputs_row`` and ``inputs_col``.

          Args:
              inputs_row: 2-D tensor; row ``i`` is paired with row ``i`` of
                  ``inputs_col``.
              inputs_col: 2-D tensor with the same shape as ``inputs_row``.

          Returns:
              List of ``num_relation_types`` 1-D tensors holding one
              activated score per row pair.
          """
          outputs = []
          for _ in range(self.num_relation_types):
              # Drop out from the ORIGINAL inputs on every iteration; rebinding
              # inputs_row / inputs_col here would compound dropout across
              # relation types.
              # NOTE(review): dropout is called regardless of self.training;
              # confirm whether the helper honours eval mode.
              row = dropout(inputs_row, 1. - self.drop_prob)
              col = dropout(inputs_col, 1. - self.drop_prob)
              # Row-wise dot product: (N, 1, D) x (N, D, 1) -> (N, 1, 1).
              rec = torch.bmm(row.unsqueeze(1), col.unsqueeze(2))
              outputs.append(self.activation(torch.flatten(rec)))
          return outputs