#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from ..weights import init_glorot
from ..dropout import dropout


class DEDICOMDecoder(torch.nn.Module):
    """DEDICOM Tensor Factorization Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                 activation=torch.sigmoid, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        self.global_interaction = init_glorot(input_dim, input_dim)
        self.local_variation = [
            torch.flatten(init_glorot(input_dim, 1))
            for _ in range(num_relation_types)
        ]

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            inputs_row = dropout(inputs_row, 1. - self.drop_prob)
            inputs_col = dropout(inputs_col, 1. - self.drop_prob)
            relation = torch.diag(self.local_variation[k])
            product1 = torch.mm(inputs_row, relation)
            product2 = torch.mm(product1, self.global_interaction)
            product3 = torch.mm(product2, relation)
            rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
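
# Note: with D_k = torch.diag(self.local_variation[k]) and R = self.global_interaction,
# DEDICOMDecoder scores every (row, col) pair for relation k as
#     activation(inputs_row @ D_k @ R @ D_k @ inputs_col.T)
# i.e. one shared interaction matrix modulated by per-relation diagonal factors.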


class DistMultDecoder(torch.nn.Module):
    """DistMult Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                 activation=torch.sigmoid, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        self.relation = [
            torch.flatten(init_glorot(input_dim, 1))
            for _ in range(num_relation_types)
        ]

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            inputs_row = dropout(inputs_row, 1. - self.drop_prob)
            inputs_col = dropout(inputs_col, 1. - self.drop_prob)
            relation = torch.diag(self.relation[k])
            intermediate_product = torch.mm(inputs_row, relation)
            rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
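
# Note: DistMultDecoder keeps only a per-relation diagonal; for each relation k it
# computes activation(inputs_row @ torch.diag(self.relation[k]) @ inputs_col.T).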


class BilinearDecoder(torch.nn.Module):
    """Bilinear Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                 activation=torch.sigmoid, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        self.relation = [
            init_glorot(input_dim, input_dim)
            for _ in range(num_relation_types)
        ]

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            inputs_row = dropout(inputs_row, 1. - self.drop_prob)
            inputs_col = dropout(inputs_col, 1. - self.drop_prob)
            intermediate_product = torch.mm(inputs_row, self.relation[k])
            rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
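
# Note: BilinearDecoder learns a full input_dim x input_dim matrix per relation type
# and computes activation(inputs_row @ self.relation[k] @ inputs_col.T).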


class InnerProductDecoder(torch.nn.Module):
    """Inner Product Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                 activation=torch.sigmoid, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            inputs_row = dropout(inputs_row, 1. - self.drop_prob)
            inputs_col = dropout(inputs_col, 1. - self.drop_prob)
            rec = torch.mm(inputs_row, torch.transpose(inputs_col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
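
A minimal usage sketch for these decoders (the shapes, the node_emb name and the drop_prob value are illustrative assumptions, not part of the module): each decoder takes row and column node embeddings and returns one score matrix per relation type. DistMultDecoder, BilinearDecoder and InnerProductDecoder expose the same interface as DEDICOMDecoder.

import torch

num_nodes, input_dim, num_relation_types = 10, 16, 3
node_emb = torch.randn(num_nodes, input_dim)      # embeddings produced by an encoder

dec = DEDICOMDecoder(input_dim, num_relation_types, drop_prob=0.2)
scores = dec(node_emb, node_emb)                  # list with one entry per relation type

assert len(scores) == num_relation_types
assert scores[0].shape == (num_nodes, num_nodes)  # pairwise link scores in (0, 1)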